import random
from typing import Optional
from rdkit.Avalon.pyAvalonTools import GetAvalonCountFP
from rdkit.Chem import rdReducedGraphs
import hashlib
import numpy as np
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from descriptastorus.descriptors import rdNormalizedDescriptors
generator = rdNormalizedDescriptors.RDKit2DNormalized()
import torch
from rdkit.Chem import rdmolops


class MorganFingerprint:
    """Dense Morgan (circular) bit fingerprints for SMILES strings.

    Args:
        shape: Number of fingerprint bits (``nBits`` for RDKit).
        radius: Morgan radius passed to RDKit.
    """

    def __init__(self, shape: Optional[int] = 2048, radius: Optional[int] = 2):
        self.shape = shape
        self.radius = radius

    @staticmethod
    def canonicalize(smiles):
        """Return RDKit's isomeric canonical SMILES, or ``smiles`` unchanged if it does not parse."""
        parsed = Chem.MolFromSmiles(smiles)
        if parsed is None:
            return smiles
        return Chem.MolToSmiles(parsed, isomericSmiles=True)

    def smiles_to_morgan(self, smile: str) -> torch.Tensor:
        """Return the Morgan bit vector of ``smile`` as a float32 tensor.

        Any exception raised while fingerprinting yields an all-zero tensor of length
        ``self.shape`` instead. Presumably this includes unparseable SMILES, for which
        RDKit returns ``None``.
        """
        try:
            canonical = self.canonicalize(smile)
            bit_vector = AllChem.GetMorganFingerprintAsBitVect(
                Chem.MolFromSmiles(canonical), self.radius, nBits=self.shape
            )
            dense = np.zeros((1,))
            # The (1,) placeholder relies on ConvertToNumpyArray resizing it to the bit count.
            DataStructs.ConvertToNumpyArray(bit_vector, dense)
        except Exception:
            dense = np.zeros((self.shape,))
        return torch.tensor(dense, dtype=torch.float32)


def prepare_input_and_labels(tokenizer, input_sequences, max_length):
    """Tokenize ``input_sequences`` and build masked-LM inputs and targets.

    Returns a dict with:
        input_ids / attention_mask: padded and truncated tokenizer output.
        labels: per-token targets produced by ``mask`` (-100 where not supervised).
        mask_input_ids: ``input_ids`` after random masking by ``mask``.
    """
    encoded = tokenizer.batch_encode_plus(
        input_sequences,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt',
        truncation=True,
    )
    mask_token = tokenizer.vocab.get('[MASK]')
    masked_rows, label_rows = [], []
    for row in encoded['input_ids']:
        masked_row, label_row = mask(row.tolist(), mask_id=mask_token)
        masked_rows.append(masked_row)
        label_rows.append(label_row)
    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': torch.tensor(label_rows),
        'mask_input_ids': torch.tensor(masked_rows),
    }


def prepare_input_and_labels_fingerprint(tokenizer, input_sequences, max_length):
    """Masked-LM batch (see ``prepare_input_and_labels``) plus per-molecule feature tensors.

    Extra keys:
        morgan: stacked 1024-bit Morgan fingerprints.
        avalon: stacked Avalon count fingerprints.
        descriptors: stacked RDKit 2D normalized descriptors (NaN replaced by 0).
        erg: stacked ErG fingerprints.

    NOTE(review): ``Chem.MolFromSmiles`` returns ``None`` for invalid SMILES. That ``None``
    is passed to the ErG/Avalon helpers, which presumably raise; confirm inputs are valid.
    """
    featurizer = MorganFingerprint(shape=1024)
    descriptor_rows, erg_rows, avalon_rows, morgan_rows = [], [], [], []
    for smiles in input_sequences:
        mol = Chem.MolFromSmiles(smiles)
        descriptor_rows.append(rdkit_2d_normalized_features(smiles))
        erg_rows.append(torch.tensor(get_erg_fingerprints(mol), dtype=torch.float32))
        avalon_rows.append(torch.tensor(get_avalon_fingerprints(mol), dtype=torch.float32))
        morgan_rows.append(featurizer.smiles_to_morgan(smiles))
    # Stack before tokenizing/masking so errors surface in the same order as before.
    morgan = torch.stack(morgan_rows)
    descriptors = torch.stack(descriptor_rows)
    avalon = torch.stack(avalon_rows)
    erg = torch.stack(erg_rows)

    outputs = prepare_input_and_labels(tokenizer, input_sequences, max_length)
    outputs['morgan'] = morgan
    outputs['avalon'] = avalon
    outputs['descriptors'] = descriptors
    outputs['erg'] = erg
    return outputs


def prepare_input_and_labels_belka(tokenizer, input_sequences,  max_length):
    """Masked-LM batch (see ``prepare_input_and_labels``) with placeholder 3D/atom fields.

    ``sequence_3d_x``, ``sequence_3d_y``, ``sequence_3d_z`` and ``atom_indices`` are all set
    to the ``labels`` tensor itself (the same object). Presumably these are stand-ins so that
    a model expecting those keys can run; confirm with the consumer.
    """
    outputs = prepare_input_and_labels(tokenizer, input_sequences, max_length)
    placeholder = outputs['labels']
    for key in ('sequence_3d_x', 'sequence_3d_y', 'sequence_3d_z', 'atom_indices'):
        outputs[key] = placeholder
    return outputs


def prepare_input_and_labels_morgan(tokenizer, input_sequences, max_length):
    """Build a Morgan-bits -> SMILES translation batch.

    The SMILES strings and their Morgan-bit token strings (from ``get_morgan``) are tokenized
    in a single call. The Morgan half becomes ``input_ids`` / ``attention_mask``. The SMILES
    half becomes ``labels``, with token id 0 replaced by -100 so it is ignored by the loss
    (presumably 0 is the pad id — confirm against the tokenizer).
    """
    n_smiles = len(input_sequences)
    encoded = tokenizer.batch_encode_plus(input_sequences + get_morgan(input_sequences),
                                          max_length=max_length,
                                          padding='max_length',
                                          return_tensors='pt',
                                          truncation=True)
    all_ids = encoded['input_ids']
    smiles_part = all_ids[:n_smiles]
    return {
        'labels': torch.where(smiles_part == 0, -100, smiles_part),
        'input_ids': all_ids[n_smiles:],
        'attention_mask': encoded['attention_mask'][n_smiles:],
    }



def pad_sequences(sequences, max_length, dtype=None):
    """Right-pad (or truncate) 1-D sequences into a ``(len(sequences), max_length)`` tensor.

    Args:
        sequences: A sized collection of 1-D sequences (lists of numbers or 1-D tensors).
        max_length: Output width. Longer sequences are truncated; shorter ones are padded with 0.
        dtype: Output dtype. When ``None`` (default) the result is ``torch.long``, unless some
            non-empty input is floating point. In that case that floating dtype is used, so
            fractional values (e.g. the normalised weights from ``get_stack_morgan``) are not
            truncated to zero.

    Returns:
        torch.Tensor of shape ``(len(sequences), max_length)``.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor inputs.
    rows = [torch.as_tensor(seq)[:max_length] for seq in sequences]

    if dtype is None:
        dtype = torch.long
        for row in rows:
            # Empty python lists become empty float tensors; they must not force a float output.
            if row.numel() > 0 and row.is_floating_point():
                dtype = row.dtype
                break

    padded = torch.zeros((len(rows), max_length), dtype=dtype)
    for i, row in enumerate(rows):
        padded[i, :row.numel()] = row
    return padded


def prepare_input_and_labels_morgan_compact(tokenizer, input_sequences, max_length):
    """Masked-LM batch over stacked Morgan-bit token strings, with per-bit weights.

    ``get_stack_morgan`` turns each SMILES into a string of non-zero bit tokens and a matching
    vector of L2-normalised weights. The strings are tokenized and randomly masked with
    ``mask``. The weight vectors are padded/truncated to ``max_length`` with ``pad_sequences``.

    NOTE(review): unlike the sibling builders, ``labels`` and ``mask_input_ids`` are returned
    as Python lists of lists, not tensors. The weights are ordered by bit, not by token
    position, so they line up with ``input_ids`` only if the tokenizer adds no leading special
    tokens. Confirm against the tokenizer.
    """
    stacked = get_stack_morgan(input_sequences)
    bit_strings = [entry[0] for entry in stacked]
    counts = pad_sequences([entry[1] for entry in stacked], max_length)
    encoded = tokenizer.batch_encode_plus(bit_strings,
                                          max_length=max_length,
                                          padding='max_length',
                                          return_tensors='pt',
                                          truncation=True)
    mask_token = tokenizer.vocab.get('[MASK]')
    masked_rows, label_rows = [], []
    for row in encoded['input_ids']:
        masked_row, label_row = mask(row.tolist(), mask_id=mask_token)
        masked_rows.append(masked_row)
        label_rows.append(label_row)
    return {
        'labels': label_rows,
        'input_ids': encoded['input_ids'],
        'mask_input_ids': masked_rows,
        'attention_mask': encoded['attention_mask'],
        'morgan_counts': counts,
    }


def get_morgan(input_sequences):
    """Encode each SMILES as the string of its set Morgan-bit indices, e.g. ``"[3][17][980]"``.

    Uses a default ``MorganFingerprint`` (2048 bits, radius 2). A SMILES whose fingerprint
    has no set bits (including the all-zero failure fallback) maps to ``""``.
    """
    featurizer = MorganFingerprint()
    encoded = []
    for smiles in input_sequences:
        bits = featurizer.smiles_to_morgan(smiles)
        on_bits = torch.nonzero(bits == 1.0, as_tuple=False).squeeze(-1).tolist()
        encoded.append("".join(f"[{bit}]" for bit in on_bits))
    return encoded


def get_stack_morgan(input_sequences):
    """Sum 4096-bit Morgan fingerprints over radii 1-5 and return (bit string, weights) pairs.

    For each SMILES the five bit vectors are added, so each bit holds how many radii set it.
    The non-zero bits are rendered as ``"[i][j]..."``, and their counts are L2-normalised into
    a 1-D float tensor in the same order.

    Returns:
        list of ``[bit_string, weights]`` lists, one per input SMILES.
    """
    featurizers = [MorganFingerprint(shape=4096, radius=radius) for radius in range(1, 6)]
    result = []
    for smiles in input_sequences:
        counts = featurizers[0].smiles_to_morgan(smiles)
        for featurizer in featurizers[1:]:
            counts += featurizer.smiles_to_morgan(smiles)
        nonzero = torch.nonzero(counts != 0.0, as_tuple=False).squeeze(-1)
        weights = counts[nonzero]
        weights = weights / torch.norm(weights, p=2)
        tokens = "".join(f"[{bit}]" for bit in nonzero.tolist())
        result.append([tokens, weights])
    return result


def get_atoms_from_smiles(smiles):
    """
    Tokenize a SMILES string into atoms, bonds, branches and ring closures with offsets.

    Despite the name, non-atom tokens are returned as well.

    Parameters
    ----------
    smiles : str
        The SMILES string to tokenize. It is sliced for the final offset check, so it must
        be a string rather than an arbitrary iterable.

    Returns
    -------
    list of tuple(str, str or int, int)
        ``(kind, token, offset)`` for each recognised token, in order:

        - ``kind`` is one of 'ATOM', 'BOND_TYPE', 'BRANCH_START', 'BRANCH_END', 'RING_NUM'
          or 'EZSTEREO'.
        - ``token`` is the matched text, or an ``int`` for 'RING_NUM'.
        - ``offset`` is the token's start index in ``smiles``.

        Characters not covered by any rule are skipped (e.g. ``@``, ``+`` or ``H`` outside
        brackets).

    Raises
    ------
    ValueError
        If the two characters after ``%`` do not form an integer (e.g. ``%`` at the end).
    AssertionError
        If a recorded offset does not point at the token's text. This can happen e.g. for a
        ring label like ``%05``, whose ``int`` prints as ``5``. The check is skipped under
        ``python -O``.
    """
    organic_subset = 'B C N O P S F Cl Br I * b c n o s p'.split()
    s = smiles
    smiles = iter(smiles)
    token = ''
    peek = None
    offset = -1
    atoms = []
    while True:
        if peek:
            # Reuse the look-ahead character; its offset was already counted below.
            char = peek
        else:
            char = next(smiles, '')
            offset += 1
        peek = None
        if not char:
            break
        if char == '[':
            # Bracket atom: consume everything up to and including the closing ']'.
            token = char
            move = 0
            for char in smiles:
                move += 1
                token += char
                if char == ']':
                    break
            atoms.append(('ATOM', token, offset))
            # Skip the offset past the characters consumed inside the brackets.
            offset += move
        elif char in organic_subset:
            # Read one character ahead to detect two-letter atoms such as Cl and Br.
            peek = next(smiles, '')
            if char + peek in organic_subset:
                atoms.append(('ATOM', char + peek, offset))
                peek = None
            else:
                atoms.append(('ATOM', char, offset))
            # Account for the look-ahead character, whether consumed or kept in `peek`.
            offset += 1
        elif char in '-=#$:.':
            atoms.append(('BOND_TYPE', char, offset))
        elif char == '(':
            atoms.append(('BRANCH_START', '(', offset))
        elif char == ')':
            atoms.append(('BRANCH_END', ')', offset))
        elif char == '%':
            # If smiles is too short this will raise a ValueError, which is
            # (slightly) prettier than a StopIteration.
            atoms.append(('RING_NUM', int(next(smiles, '') + next(smiles, '')), offset+1))
            offset += 2
        elif char in '/\\':
            atoms.append(('EZSTEREO', char, offset))
        elif char.isdigit():
            atoms.append(('RING_NUM', int(char), offset))
    # Sanity check: every recorded offset must point at the token's text in the input.
    for _, a, offset in atoms:
        assert str(a) == s[offset: (offset+len(str(a)))]
    return atoms


def mask_supersmiles(sequence, mask_id, max_k=2, hash_size=1024, mask_prob=0.05):
    """Span-mask a token-id sequence and build hashed-context labels for the span centres.

    Each position is selected with probability ``mask_prob``. A selected position whose
    token id is not special (0, 12 or 13) becomes a span centre and gets ``label_idx`` 1.
    Every other position gets ``label_idx`` -100.

    ``replace_elements`` then overwrites every position within ``max_k`` of each centre with
    ``mask_id``. Special tokens inside such a window are overwritten too.

    Args:
        sequence: list of token ids.
        mask_id: id written into masked positions.
        max_k: half-width of the masked span and largest context radius for the labels.
        hash_size: width of each label row.
        mask_prob: per-position selection probability (default 0.05).

    Returns:
        ``(masked_sequence, labels, label_idx)``:

        - ``masked_sequence``: the masked ids, as a list.
        - ``labels``: the ``create_signature`` array, one row of ``hash_size`` per position,
          non-zero only at centres.
        - ``label_idx``: the per-position 1 / -100 indicator list.
    """
    special_ids = (0, 12, 13)
    masked_sequence = []
    label_idx = []
    centres = []
    for i, token in enumerate(sequence):
        prob = random.random()
        if prob >= mask_prob:
            masked_sequence.append(token)
            label_idx.append(-100)
            continue

        # Rescale to [0, 1) so the 90/10 split below is relative to the selected positions.
        # The threshold itself must be the divisor; dividing by any larger value would make
        # the 10% keep branch unreachable.
        prob /= mask_prob
        is_special = token in special_ids
        if is_special:
            label_idx.append(-100)
        else:
            centres.append(i)
            label_idx.append(1)

        if prob < 0.9 and not is_special:
            masked_sequence.append(mask_id)  # 90%: replace with the mask token
        else:
            masked_sequence.append(token)    # 10%, or special token: keep as-is

    # Each centre is re-masked here together with its +-max_k neighbours. As a result, the
    # keep branch above does not survive at the centre positions themselves.
    masked_sequence = replace_elements(masked_sequence, centres, mask_id, max_k)
    labels = create_signature(sequence, max_k, hash_size, label_idx)
    return masked_sequence, labels, label_idx


def replace_elements(data_list, indices, value, k):
    """Return a copy of ``data_list`` with a radius-``k`` window around each index set to ``value``.

    For every ``index`` in ``indices``, positions ``index - k`` .. ``index + k`` (clipped to the
    list bounds) are overwritten. ``data_list`` itself is left unmodified.
    """
    result = data_list.copy()
    size = len(data_list)
    for centre in indices:
        lo = max(0, centre - k)
        hi = min(size, centre + k + 1)
        for position in range(lo, hi):
            result[position] = value
    return result


def mask(sequence, mask_id):
    """BERT-style random masking of a list of token ids.

    Each position is selected with probability 0.15. A selected position whose id is not
    special (0, 12 or 13) gets its original id as label. Its input is replaced by
    ``mask_id`` 90% of the time and kept 10% of the time. Special ids are never replaced or
    labelled. Unselected positions keep their id and get label -100.

    If no position ends up labelled and the sequence has at least two tokens, position 1 is
    labelled with its own id (its input is left unmasked). Every such sequence therefore
    contributes at least one target, presumably to avoid an all -100 label row.

    Returns:
        ``(masked_sequence, labels)``: two lists the same length as ``sequence``.
    """
    special_ids = (0, 12, 13)
    masked_sequence = []
    labels = []
    for token in sequence:
        prob = random.random()
        if prob < 0.15:
            prob /= 0.15
            is_special = token in special_ids
            labels.append(-100 if is_special else token)
            if prob < 0.9 and not is_special:
                masked_sequence.append(mask_id)  # 90%: replace with the mask token
            else:
                masked_sequence.append(token)    # 10%, or special token: keep
        else:
            masked_sequence.append(token)
            labels.append(-100)
    # The fallback must inspect the labels. Counting ids != -100 in `sequence` (token ids)
    # is almost never zero, so it would never trigger.
    if len(sequence) > 1 and all(label == -100 for label in labels):
        labels[1] = sequence[1]
    return masked_sequence, labels


def get_context_symbols(symbols, index, k):
    """Return up to ``k`` symbols on each side of ``symbols[index]`` (excluding it).

    Returns:
        ``(left, right)`` slices of ``symbols``. Both are shorter near the edges.
    """
    start = max(0, index - k)
    # Slicing clamps the upper bound to len(symbols) by itself.
    return symbols[start:index], symbols[index + 1:index + k + 1]


def concatenate_context(left_context, middle, right_context):
    """Build an order-insensitive key for a context window.

    The three sequences are concatenated and sorted, and the ``str`` of the sorted list is
    returned.
    """
    return str(sorted(left_context + middle + right_context))


def hash_value(input_string):
    """Return a process-independent integer hash of ``input_string``.

    The value is the SHA-256 digest of the UTF-8 encoded string, read as a big-endian
    unsigned integer. Unlike the built-in ``hash`` it is not salted per process.
    """
    digest = hashlib.sha256(input_string.encode()).digest()
    return int.from_bytes(digest, "big")


def create_signature(symbols, max_k, m, label_idx):
    """Hash each labelled position's context windows into an ``m``-bit row.

    For every position with ``label_idx != -100``, the windows of radius ``0 .. max_k``
    around it are turned into order-insensitive keys (``concatenate_context``) and hashed
    with ``hash_value``. Each window sets bit ``hash % m``. Unlabelled positions get an
    all-zero row.

    Returns:
        ``np.array`` built from one length-``m`` list per position.
    """
    rows = []
    for position, symbol in enumerate(symbols):
        row = [0] * m
        if label_idx[position] != -100:
            for radius in range(max_k + 1):
                left, right = get_context_symbols(symbols, position, radius)
                key = concatenate_context(left, [symbol], right)
                row[hash_value(key) % m] = 1
        rows.append(row)
    return np.array(rows)


def prepare_input_and_labels_supersmiles(tokenizer, input_sequences, sequences_3d, max_length, hash_size):
    """Tokenize ``input_sequences`` and apply ``mask_supersmiles`` span masking to each row.

    ``sequences_3d`` is accepted but not used.

    Returns a dict with:
        input_ids / attention_mask: padded and truncated tokenizer output.
        labels: stacked hashed-context signatures.
        mask_input_ids: the span-masked ids.
        label_indices: the per-position 1 / -100 indicators.
    """
    encoded = tokenizer.batch_encode_plus(input_sequences, max_length=max_length, padding='max_length',
                                          return_tensors='pt', truncation=True)
    mask_token = tokenizer.vocab.get('[MASK]')
    masked_rows, signature_rows, indicator_rows = [], [], []
    for row in encoded['input_ids']:
        masked_row, signature, indicator = mask_supersmiles(row.tolist(),
                                                            mask_id=mask_token,
                                                            hash_size=hash_size)
        masked_rows.append(masked_row)
        signature_rows.append(signature)
        indicator_rows.append(indicator)
    return {
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': torch.tensor(signature_rows),
        'mask_input_ids': torch.tensor(masked_rows),
        'label_indices': torch.tensor(indicator_rows),
    }


def align(x, y):
    """Greedily align atom symbols ``x`` with token strings ``y``.

    Walks both sequences in lockstep and skips ``'H'`` entries in ``x``. Each ``x[i]`` is
    matched to the next ``y[j]`` when its lowercase form is a substring of ``y[j]``'s lowercase
    form. Stops at the first mismatch or when either sequence runs out.

    Returns:
        dict mapping matched indices of ``x`` to indices of ``y``.
    """
    mapping = {}
    j = 0
    for i, atom in enumerate(x):
        if j >= len(y):
            break
        if atom == 'H':
            continue
        if atom.lower() not in y[j].lower():
            break
        mapping[i] = j
        j += 1
    return mapping


def get_graph(bonds, types):
    """Build an adjacency map ``{source: [(bond_type, target), ...]}``.

    ``bonds`` and ``types`` are parallel sequences; iteration stops at the shorter one. Only
    the ``bond[0] -> bond[1]`` direction is recorded, so undirected graphs must list both
    directions in ``bonds``.
    """
    graph = {}
    for edge, bond_type in zip(bonds, types):
        graph.setdefault(edge[0], []).append((bond_type, edge[1]))
    return graph


def get_subgraph_fingerprint(g, idx, atoms, radius,  fingerprint_size):
    """Hash the breadth-first neighbourhood of atom ``idx`` into a binary fingerprint.

    The root sets bit ``hash_value(atoms[idx]) % fingerprint_size``. Each BFS step (at most
    ``radius`` of them) visits not-yet-seen neighbours. A neighbour's hash chains its parent's
    hash, the bond label and its own atom label, and sets the corresponding bit. Each atom is
    visited at most once.

    Args:
        g: adjacency map ``{atom: [(bond_label, neighbour), ...]}``, as built by ``get_graph``.
        idx: root atom index.
        atoms: atom labels indexed by atom; ``atoms[idx]`` must be a ``str``.
        radius: maximum number of bonds to expand from the root.
        fingerprint_size: number of bits.

    Returns:
        ``(fingerprint, visited)``: a list of 0/1 ints of length ``fingerprint_size``, and the
        set of atom indices reached, including ``idx``.
    """
    bits = [0] * fingerprint_size
    root_hash = hash_value(atoms[idx])
    bits[root_hash % fingerprint_size] = 1
    # Frontier entries are (atom, hash of the path that reached it). The set's iteration
    # order decides which parent claims a shared neighbour, so it must stay a set.
    frontier = {(idx, root_hash)}
    visited = {idx}
    depth = 0
    while frontier and depth < radius:
        next_frontier = set()
        for node, node_hash in frontier:
            for entry in g.get(node, ()):
                bond_label, neighbour = entry[0], entry[1]
                if neighbour in visited:
                    continue
                path_key = "_".join((str(node_hash), str(bond_label), str(atoms[neighbour])))
                path_hash = hash_value(path_key)
                bits[path_hash % fingerprint_size] = 1
                next_frontier.add((neighbour, path_hash))
                visited.add(neighbour)
        frontier = next_frontier
        depth += 1
    return bits, visited


def prepare_input_and_labels_smilesgraph_old(atoms, tokenizer, input_sequences,
                                         bond_edges, bond_types, max_length,
                                         radius):
    """Legacy subgraph-masking batch builder driven by externally supplied atom graphs.

    For each sequence:

    1. The tokenizer's atom tokens are aligned (``align``) with the atom symbols in
       ``atoms[j]``.
    2. One aligned atom is sampled, and the 256-bit hashed fingerprint of its ``radius``-bond
       subgraph (``get_subgraph_fingerprint``) becomes the label.
    3. The tokens of every aligned atom in that subgraph are replaced by ``[MASK]`` in place.

    If nothing aligns, an all-zero label with index 0 is used and nothing is masked.

    Args:
        atoms: per-sequence lists of atom strings. Text before the first '.' is used as the
            symbol (presumably a "symbol.suffix" format — TODO confirm).
        tokenizer: provides ``batch_encode_plus``, ``decode``, ``vocab`` and ``get_atom_indices``.
        input_sequences: SMILES strings.
        bond_edges, bond_types: per-sequence edge lists and bond labels for ``get_graph``.
        max_length: padding / truncation length.
        radius: subgraph radius in bonds.

    Returns:
        dict with:

        - ``input_ids``: presumably the same tensor that is masked in place below.
        - ``attention_mask``.
        - ``labels``: shape (batch, 1, 256).
        - ``label_indices``: shape (batch, 1).
    """
    outputs = {}
    inputs = tokenizer.batch_encode_plus(input_sequences, max_length=max_length, padding='max_length',
                                         return_tensors='pt', truncation=True)
    outputs['input_ids'] = inputs['input_ids']
    outputs['attention_mask'] = inputs['attention_mask']
    atom_indices = []
    for sequence in input_sequences:
        # +1 presumably shifts past a leading special token such as [CLS] — TODO confirm.
        atom_indices.append(torch.tensor(tokenizer.get_atom_indices(sequence))+1)
    labels = []
    label_indices = []
    for j in range(len(inputs.input_ids)):
        sa = tokenizer.decode(inputs.input_ids[j]).split()
        sa = [sa[i] for i in atom_indices[j] if i < max_length]
        xa = [a.split(".")[0] for a in atoms[j]]
        al = align(xa, sa)
        g = get_graph(bond_edges[j], bond_types[j])
        mask_indices = []
        fingerprint_indices = []
        fingerprints = []
        lal = list(al.keys())
        if len(lal) == 0:
            print("Empty alignment", xa, sa)
            # A plain list, not a tensor: torch.tensor(labels) below cannot convert a
            # nested multi-element tensor mixed with lists.
            fingerprints.append([0] * 256)
            fingerprint_indices.append(0)
        else:
            r = random.randint(0, len(lal)-1)
            idx = lal[r]
            fingerprint, indices = get_subgraph_fingerprint(g, idx, xa, radius=radius, fingerprint_size=256)
            # Map subgraph atoms to token positions; unaligned atoms are dropped.
            indices = [atom_indices[j][al[i]].item() for i in indices if i in al]
            mask_indices.extend(indices)
            fingerprints.append(fingerprint)
            fingerprint_indices.append(atom_indices[j][al[idx]].item())
        inputs.input_ids[j][mask_indices] = tokenizer.vocab.get('[MASK]')
        labels.append(fingerprints)
        label_indices.append(fingerprint_indices)
    outputs['labels'] = torch.tensor(labels)
    outputs['label_indices'] = torch.tensor(label_indices)
    return outputs


def prepare_input_and_labels_smilesgraph(tokenizer, input_sequences, max_length, radius, hash_size, num_sm=1):
    """Build a batch with one masked subgraph and three masked singleton tokens per sequence.

    Relies on a project tokenizer whose ``batch_encode(..., max_len=...)`` result provides
    ``input_ids``, ``attention_mask``, ``atoms``, ``bond_edges``, ``bond_types`` and
    ``alignments`` (per sequence, atom index -> token position).

    For each sequence:

    * One atom index is sampled. If its token position is < ``max_length``, the hashed
      fingerprint of its ``radius``-bond subgraph (``get_subgraph_fingerprint``) becomes a
      label, and every in-range token of that subgraph is masked.
    * Three token positions inside the attention mask are sampled. Each is labelled with a
      one-hot hash of its current token id and then masked.

    ``num_sm`` is accepted but not used here: exactly one subgraph is sampled.

    NOTE(review): if the sampled atom's token lies at or beyond ``max_length``, no subgraph
    label is appended for that row. ``labels`` then becomes ragged, and ``torch.tensor(labels)``
    would likely raise. An empty ``alignments[j]`` makes ``random.randint`` raise ValueError.

    Returns:
        dict with ``input_ids`` (masked in place), ``attention_mask``, ``labels``,
        ``label_indices``, ``singleton_labels`` and ``singleton_label_indices``.
    """
    outputs = {}
    inputs = tokenizer.batch_encode(input_sequences, max_len=max_length)
    outputs['input_ids'] = inputs['input_ids']
    outputs['attention_mask'] = inputs['attention_mask']
    atoms = inputs['atoms']
    bond_edges = inputs['bond_edges']
    bond_types = inputs['bond_types']
    alignments = inputs['alignments']
    labels = []
    label_indices = []
    singleton_labels = []
    singleton_label_indices = []
    for j in range(len(inputs['input_ids'])):
        xa = atoms[j]
        al = alignments[j]  # a map from xa indices to input_ids indices
        g = get_graph(bond_edges[j], bond_types[j])
        mask_indices = []
        fingerprint_indices = []
        fingerprints = []
        singleton_fingerprint_indices = []
        singleton_fingerprints = []
        for _ in range(1):
            idx = random.randint(0, len(al)-1)
            if al[idx] < max_length:
                fingerprint, indices = get_subgraph_fingerprint(g, idx, xa, radius=radius, fingerprint_size=hash_size)
                # Token positions of the subgraph atoms that survived truncation.
                indices = [al[i] for i in indices if al[i] < max_length]
                mask_indices.extend(indices)
                fingerprints.append(fingerprint)
                fingerprint_indices.append(al[idx])
        # mask singleton
        for _ in range(3):
            # Any attended position may be picked, including special tokens, and picks can
            # repeat. A repeated pick hashes the [MASK] id written by the earlier pick.
            idx = random.randint(0, outputs["attention_mask"][j].sum()-1)
            fingerprint = [0 for _ in range(hash_size)]
            h = hash_value(str(outputs['input_ids'][j][idx].item()))
            fingerprint[h % hash_size] = 1
            outputs['input_ids'][j][idx] = tokenizer.vocab['[MASK]']
            singleton_fingerprints.append(fingerprint)
            singleton_fingerprint_indices.append(idx)
            mask_indices.append(idx)
        # Apply the subgraph masking (and re-apply the singleton masking) in place.
        outputs['input_ids'][j][mask_indices] = tokenizer.vocab['[MASK]']
        labels.append(fingerprints)
        label_indices.append(fingerprint_indices)
        singleton_labels.append(singleton_fingerprints)
        singleton_label_indices.append(singleton_fingerprint_indices)
    outputs['labels'] = torch.tensor(labels)
    outputs['label_indices'] = torch.tensor(label_indices)
    outputs['singleton_labels'] = torch.tensor(singleton_labels)
    outputs['singleton_label_indices'] = torch.tensor(singleton_label_indices)
    return outputs


def prepare_input_and_labels_logo(input_sequences, radius, hash_size, num_sm=1):
    """Sample subgraph fingerprints over RDKit molecular graphs (unfinished).

    For each graph from ``smiles_to_molecular_graph``, ``num_sm`` atoms are sampled and the
    hashed fingerprint of each atom's ``radius``-bond subgraph is computed. Atomic numbers
    (as strings) serve as atom labels.

    NOTE(review): every value computed here is discarded and an empty dict is returned. The
    output schema appears unfinished — confirm before relying on this function. Invalid
    SMILES are skipped by ``smiles_to_molecular_graph``, so graphs may not line up with
    ``input_sequences``.
    """
    outputs = {}
    graphs = smiles_to_molecular_graph(input_sequences)
    batch_fingerprint_indices = []
    batch_fingerprints = []
    ptr = [0]
    node_ids = []
    for graph in graphs:
        adjacency = get_graph(graph['edge_index'], graph['edge_types'])
        atom_labels = [str(z) for z in graph['atomic_numbers']]
        centres, sampled_fingerprints, masked_atoms = [], [], []
        for _ in range(num_sm):
            centre = random.randint(0, len(atom_labels) - 1)
            fingerprint, members = get_subgraph_fingerprint(adjacency, centre, atom_labels,
                                                            radius=radius, fingerprint_size=hash_size)
            centres.append(centre)
            sampled_fingerprints.append(fingerprint)
            masked_atoms.extend(members)

        ptr.append(ptr[-1] + len(atom_labels))
        node_ids.extend(graph['atomic_numbers'])
        batch_fingerprint_indices.append(centres)
        batch_fingerprints.append(sampled_fingerprints)
    return outputs


def prepare_input_and_labels_sm(tokenizer, input_sequences, max_length, radius, hash_size, num_sm=1):
    """Build subgraph fingerprint labels plus three masked singleton tokens per sequence.

    Relies on a project tokenizer whose ``batch_encode(..., max_len=...)`` result provides
    ``input_ids``, ``attention_mask``, ``atoms``, ``bond_edges``, ``bond_types`` and
    ``alignments`` (per sequence, atom index -> token position).

    For each sequence:

    * ``num_sm`` atom indices are sampled. For each one whose token position is
      < ``max_length``, the hashed fingerprint of its ``radius``-bond subgraph becomes a label.
    * Three token positions inside the attention mask are sampled, labelled with a one-hot
      hash of their current token id, and masked.

    NOTE(review): the subgraph token positions are collected in ``mask_indices_sm``, which is
    never applied, so only the singleton positions are masked. Confirm this is intended. Rows
    can also get fewer than ``num_sm`` subgraph labels when a sampled atom lies beyond
    ``max_length``; the ragged ``labels`` would then likely make ``torch.tensor`` raise.

    Returns:
        dict with ``input_ids`` (masked in place), ``attention_mask``, ``labels``,
        ``label_indices``, ``singleton_labels`` and ``singleton_label_indices``.
    """
    outputs = {}
    inputs = tokenizer.batch_encode(input_sequences, max_len=max_length)
    outputs['input_ids'] = inputs['input_ids']
    outputs['attention_mask'] = inputs['attention_mask']
    atoms = inputs['atoms']
    bond_edges = inputs['bond_edges']
    bond_types = inputs['bond_types']
    alignments = inputs['alignments']
    labels = []
    label_indices = []
    singleton_labels = []
    singleton_label_indices = []
    for j in range(len(inputs['input_ids'])):
        xa = atoms[j]
        al = alignments[j]  # a map from xa indices to input_ids indices
        g = get_graph(bond_edges[j], bond_types[j])
        mask_indices_sm = []
        mask_indices = []
        fingerprint_indices = []
        fingerprints = []
        singleton_fingerprint_indices = []
        singleton_fingerprints = []
        for _ in range(num_sm):
            idx = random.randint(0, len(al)-1)
            if al[idx] < max_length:
                fingerprint, indices = get_subgraph_fingerprint(g, idx, xa, radius=radius, fingerprint_size=hash_size)
                indices = [al[i] for i in indices if al[i] < max_length]
                # Collected but never used to mask input_ids (see NOTE above).
                mask_indices_sm.extend(indices)
                fingerprints.append(fingerprint)
                fingerprint_indices.append(al[idx])
        # mask singleton
        for _ in range(3):
            # Picks may repeat; a repeated pick hashes the [MASK] id written by the earlier one.
            idx = random.randint(0, outputs["attention_mask"][j].sum()-1)
            fingerprint = [0 for _ in range(hash_size)]
            h = hash_value(str(outputs['input_ids'][j][idx].item()))
            fingerprint[h % hash_size] = 1
            outputs['input_ids'][j][idx] = tokenizer.vocab['[MASK]']
            singleton_fingerprints.append(fingerprint)
            singleton_fingerprint_indices.append(idx)
            mask_indices.append(idx)
        outputs['input_ids'][j][mask_indices] = tokenizer.vocab['[MASK]']
        labels.append(fingerprints)
        label_indices.append(fingerprint_indices)
        singleton_labels.append(singleton_fingerprints)
        singleton_label_indices.append(singleton_fingerprint_indices)
    outputs['labels'] = torch.tensor(labels)
    outputs['label_indices'] = torch.tensor(label_indices)
    outputs['singleton_labels'] = torch.tensor(singleton_labels)
    outputs['singleton_label_indices'] = torch.tensor(singleton_label_indices)
    return outputs


def pca(X, k):
    """
    Project the rows of X onto its top-k principal components.

    Args:
    X (torch.Tensor): Real floating-point data matrix of shape (m, n); rows are samples.
    k (int): Number of components to keep (at most n are available).

    Returns:
    torch.Tensor: Matrix of shape (m, min(k, n)): the centred data projected onto the
        leading eigenvectors of its covariance, in order of decreasing eigenvalue. The sign
        of each component is arbitrary.
    """
    # Center the data
    X_mean = torch.mean(X, dim=0)
    X_centered = X - X_mean

    # The (m - 1) normalisation does not affect the eigenvectors. Clamp it so a single-sample
    # input does not divide by zero and turn the covariance into NaNs.
    denom = max(X_centered.shape[0] - 1, 1)
    cov_matrix = torch.mm(X_centered.t(), X_centered) / denom

    # The covariance is real symmetric, so use eigh. It returns real eigenvalues and
    # orthonormal eigenvectors. General `eig` yields complex output and does not guarantee
    # orthogonal eigenvectors for repeated eigenvalues, which are common when m < n.
    eigenvalues, eigenvectors = torch.linalg.eigh(cov_matrix)

    # eigh returns ascending eigenvalues; order the components by decreasing variance.
    order = torch.argsort(eigenvalues, descending=True)
    top_k_eigenvectors = eigenvectors[:, order[:k]]

    # Transform the data
    return torch.mm(X_centered, top_k_eigenvectors)


def pca_twice(A, k, l):
    """
    Apply ``pca`` along both axes of A: first keep k column components, then l row components.

    The first pass treats the m rows of A as samples and reduces the n columns to k, giving
    an (m, k) matrix. The second pass transposes that result, treats its k rows as samples,
    reduces the m columns to l, and transposes back.

    Args:
    A (torch.Tensor): The input data matrix of shape (m, n)
    k (int): The number of columns to keep (first pass)
    l (int): The number of rows to keep (second pass)

    Returns:
    torch.Tensor: The reduced matrix of shape (l, k), assuming k <= n and l <= m.
    """
    # First PCA: columns n -> k, shape (m, k)
    reduced_columns = pca(A, k)

    # Second PCA on the transpose: rows m -> l, giving (k, l), transposed back to (l, k)
    return pca(reduced_columns.t(), l).t()


def count_to_array(fingerprint):
    """Convert an RDKit fingerprint into a 1-D ``np.int8`` array.

    The zero-length placeholder is filled (and resized) by ``DataStructs.ConvertToNumpyArray``.
    NOTE(review): counts above 127 do not fit in int8. Confirm that count fingerprints passed
    here stay in range, or widen the dtype.
    """
    out = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fingerprint, out)
    return out


def get_avalon_fingerprints(molecules, n_bits=1024):
    """Return the Avalon count fingerprint of one RDKit molecule as a NumPy array.

    Despite its plural name, ``molecules`` is a single molecule; it is passed straight to
    ``GetAvalonCountFP``. The result comes from ``count_to_array`` and presumably has
    ``n_bits`` entries.
    """
    return count_to_array(GetAvalonCountFP(molecules, nBits=n_bits))


def get_erg_fingerprints(molecules):
    """Return RDKit's ErG (extended reduced graph) fingerprint for one molecule.

    Despite its plural name, ``molecules`` is a single RDKit molecule.
    """
    return rdReducedGraphs.GetErGFingerprint(molecules)


def rdkit_2d_normalized_features(smiles: str):
    """Return descriptastorus RDKit 2D normalized descriptors for ``smiles`` as float32.

    The first element of the generator's output (presumably a success flag) is dropped, and
    NaN entries are replaced by 0.0.
    """
    raw = generator.process(smiles)
    descriptor = torch.tensor(raw[1:], dtype=torch.float32)
    descriptor.masked_fill_(torch.isnan(descriptor), 0.0)
    return descriptor


def smiles_to_molecular_graph(smiles_list):
    """Convert SMILES strings into simple graph dicts; unparseable SMILES are skipped.

    Each graph dict has:
        atomic_numbers: atomic number of every atom, in RDKit atom order.
        edge_index: ``[i, j]`` pairs, with each bond listed in both directions.
        edge_types: per-edge bond code (1 single, 2 double, 3 triple, 4 aromatic, 0 other),
            parallel to ``edge_index``.
    """
    bond_codes = {
        Chem.rdchem.BondType.SINGLE: 1,
        Chem.rdchem.BondType.DOUBLE: 2,
        Chem.rdchem.BondType.TRIPLE: 3,
        Chem.rdchem.BondType.AROMATIC: 4,
    }
    graphs = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            continue

        edge_index, edge_types = [], []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            code = bond_codes.get(bond.GetBondType(), 0)
            # The graph is undirected, so store both directions with the same code.
            edge_index += [[begin, end], [end, begin]]
            edge_types += [code, code]

        graphs.append({
            "atomic_numbers": [atom.GetAtomicNum() for atom in mol.GetAtoms()],
            "edge_index": edge_index,
            "edge_types": edge_types,
        })
    return graphs
