import os
import random
import numpy as np
import pandas as pd
import deepchem as dc

from sklearn.model_selection import train_test_split
from scipy.sparse import csgraph
from scipy.linalg import eigh
from sklearn.cluster import SpectralClustering
from kneed import KneeLocator

from tqdm import tqdm
from collections import Counter

import torch
from torch_geometric.data import Data, Batch

import numpy as np
from rdkit import Chem
from rdkit.Chem import rdFMCS


from joblib import Parallel, delayed


# Vocabularies used to turn RDKit atom and bond properties into integer
# indices: the position of a value inside its list is the feature index.
allowable_features = dict(
    # Atomic numbers 1..118 (index 0 corresponds to atomic number 1).
    possible_atomic_num_list=[*range(1, 119)],
    # Atom chirality tags.
    possible_chirality_list=[
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    # Bond orders.
    possible_bonds=[
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # Bond stereo directions.
    possible_bond_dirs=[
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
)

def mol_to_graph(mol):
    """Convert an RDKit molecule into a PyTorch Geometric ``Data`` graph.

    Each atom becomes a node with two integer features: the index of its
    atomic number and the index of its chirality tag in
    ``allowable_features``. Each bond becomes two directed edges (one per
    direction) with two integer features: the index of its bond type and
    the index of its bond direction.

    A value that is missing from its vocabulary falls back to index 0, so
    it cannot be told apart from the first vocabulary entry.

    Args:
        mol: Molecule exposing ``GetAtoms()`` and ``GetBonds()``.

    Returns:
        ``Data`` with ``x`` of shape ``(num_atoms, 2)``, ``edge_index`` of
        shape ``(2, 2 * num_bonds)`` and ``edge_attr`` of shape
        ``(2 * num_bonds, 2)``, all ``torch.long``.
    """
    def _feature_index(vocabulary_key, value):
        # Position of `value` in the vocabulary, or 0 when it is not listed.
        try:
            return allowable_features[vocabulary_key].index(value)
        except ValueError:
            return 0

    # Process atoms.
    atom_features_list = [
        [
            _feature_index('possible_atomic_num_list', atom.GetAtomicNum()),
            _feature_index('possible_chirality_list', atom.GetChiralTag()),
        ]
        for atom in mol.GetAtoms()
    ]
    # The reshape keeps the (num_atoms, 2) layout for a molecule without
    # atoms, where np.array([]) would otherwise yield a 1-D tensor.
    x = torch.tensor(np.array(atom_features_list), dtype=torch.long).reshape(-1, 2)

    # Process bonds: every bond is stored once per direction.
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        bond_features = [
            _feature_index('possible_bonds', bond.GetBondType()),
            _feature_index('possible_bond_dirs', bond.GetBondDir()),
        ]
        edges_list.append((begin, end))
        edge_features_list.append(bond_features)
        edges_list.append((end, begin))
        edge_features_list.append(bond_features)

    if edges_list:
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:
        # No bonds: keep explicitly shaped empty tensors.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, 2), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


def build_dataset(graph_data_list, valid_reps, valid_smiles, valid_targets, indices):
    """Gather the entries at ``indices`` from each parallel sequence.

    Args:
        graph_data_list: Sequence of graphs, indexable by position.
        valid_reps: Sequence of representations, indexable by position.
        valid_smiles: Sequence of SMILES strings, indexable by position.
        valid_targets: Sequence of targets, indexable by position.
        indices: Positions to select, in the desired output order.

    Returns:
        Dict with keys ``'reps'``, ``'graphs'``, ``'smiles'`` and
        ``'targets'``, each holding the selected items as a list.
    """
    columns = (
        ('reps', valid_reps),
        ('graphs', graph_data_list),
        ('smiles', valid_smiles),
        ('targets', valid_targets),
    )
    return {key: [values[i] for i in indices] for key, values in columns}


def save_dataset_as_csv(args, dataset):
    """Write the train/eval/ood splits to ``<split>_featurized.csv`` files.

    Files are written under
    ``args.data_dir/args.dataset_name/args.split_x_type/args.property``,
    which is created if needed. Each CSV has a ``smiles`` column and a
    ``target`` column (the split's targets flattened to 1-D). The first row
    of every split is printed as a sample.

    Args:
        args: Object with ``data_dir``, ``dataset_name``, ``split_x_type``
            and ``property`` attributes.
        dataset: Mapping with ``'train'``, ``'eval'`` and ``'ood'`` entries,
            each a mapping with ``'smiles'`` and ``'targets'``.
    """
    dir_path = os.path.join(args.data_dir, args.dataset_name, args.split_x_type, args.property)
    os.makedirs(dir_path, exist_ok=True)
    for split in ['train', 'eval', 'ood']:
        df = pd.DataFrame({
            'smiles': dataset[split]['smiles'],
            'target': np.array(dataset[split]['targets']).reshape(-1),
        })

        print(f'\n{split} data samples')
        if df.empty:
            # An empty split (e.g. no OOD molecules) has no first row to
            # show; it must still be written rather than crash here.
            print('(no rows)')
        else:
            for col in df.columns:
                # Positional access: independent of the index labels.
                print(col, df[col].iloc[0])

        df.to_csv(os.path.join(dir_path, f'{split}_featurized.csv'), index=False)



def organize_dataframe(uniques, smiles):
    """Summarise scaffold groups as a DataFrame with one row per scaffold.

    Args:
        uniques: Mapping from scaffold to the list of positions in
            ``smiles`` that belong to it.
        smiles: Sequence of SMILES strings indexed by those positions.

    Returns:
        DataFrame with columns ``scaffolds`` (the mapping keys),
        ``original_smiles`` (member SMILES joined with ``';'``) and
        ``n_mols`` (number of members).
    """
    scaffolds = list(uniques)
    member_lists = list(uniques.values())
    joined_members = [
        ';'.join(smiles[i] for i in members) for members in member_lists
    ]
    member_counts = [len(members) for members in member_lists]
    return pd.DataFrame({
        'scaffolds': scaffolds,
        'original_smiles': joined_members,
        'n_mols': member_counts,
    })


def group_and_sort(clusters, similarity_matrix, n_smiles_for_each_scaffold):
    """Summarise clusters and sort them by mean similarity to other clusters.

    Args:
        clusters: Cluster label of every scaffold (list or array).
        similarity_matrix: Pairwise scaffold similarity matrix, indexed in
            the same order as ``clusters``.
        n_smiles_for_each_scaffold: Number of molecules behind each
            scaffold, in the same order as ``clusters``.

    Returns:
        DataFrame sorted by ascending ``mean_sim`` with one row per cluster:
        ``cluster`` (label), ``size_scaffolds`` (scaffolds in the cluster),
        ``size_full`` (molecules in the cluster) and ``mean_sim`` (mean
        similarity to all other clusters).
    """
    # Work on arrays so that `clusters == label` is an element-wise mask
    # even when the caller passes plain lists.
    clusters = np.asarray(clusters)
    n_smiles = np.asarray(n_smiles_for_each_scaffold)

    clust_sims = cluster_similarity(similarity_matrix, clusters)
    mean_sims = mean_cluster_sim_to_all_clusters(clust_sims)

    # Sorted unique labels with their scaffold counts. Taking both from one
    # call keeps each count attached to its own cluster; sorting the counts
    # independently would pair them with the wrong labels.
    labels, scaffold_counts = np.unique(clusters, return_counts=True)
    # Total number of molecules per cluster (summed across its scaffolds).
    cluster_size = [n_smiles[clusters == label].sum() for label in labels]

    df = pd.DataFrame({
        'cluster': labels.tolist(),
        'size_scaffolds': scaffold_counts.tolist(),
        'size_full': cluster_size,
        'mean_sim': mean_sims,
    })
    return df.sort_values(by=['mean_sim'])

def cluster_similarity(X: np.ndarray, clusters: np.ndarray) -> np.ndarray:
    """Compute the mean pairwise similarity between every pair of clusters.

    Args:
        X: Square similarity matrix between items.
        clusters: Cluster label of each item (list or array), in the same
            order as the rows/columns of ``X``.

    Returns:
        Symmetric ``(k, k)`` matrix for ``k`` distinct labels. Rows and
        columns follow the sorted unique labels; entry ``[a, b]`` is the
        mean of ``X`` over all (item in cluster a, item in cluster b) pairs.
        The diagonal holds the within-cluster mean, including each item's
        similarity to itself.
    """
    X = np.asarray(X)
    # A plain list would make `clusters == label` a single bool instead of
    # an element-wise mask, so normalise to an array first.
    clusters = np.asarray(clusters)
    # Sorted unique labels: labels need not be exactly 0..k-1.
    labels = np.unique(clusters)
    members = [np.flatnonzero(clusters == label) for label in labels]

    n_clusters = len(labels)
    clust_sims = np.zeros((n_clusters, n_clusters))
    # Fill the upper triangle (including the diagonal) only.
    for i in range(n_clusters):
        for j in range(i, n_clusters):
            clust_sims[i, j] = X[np.ix_(members[i], members[j])].mean()
    # Mirror the upper triangle; subtract the diagonal once because it is
    # present in both `clust_sims` and its transpose.
    return clust_sims + clust_sims.T - np.diag(np.diag(clust_sims))

def eigenvalue_cluster_approx(S: np.ndarray) -> int:
    """Estimate a cluster count from the knee of the Laplacian spectrum.

    The normalised graph Laplacian of ``S`` is eigendecomposed and
    ``KneeLocator`` is run on the eigenvalues (concave, increasing curve).

    Args:
        S: Similarity/affinity matrix.

    Returns:
        The detected knee position as an int, or 2 when no knee is found.
    """
    spectrum = eigh(csgraph.laplacian(S, normed=True))[0]
    locator = KneeLocator(
        range(len(spectrum)),
        spectrum,
        curve='concave',
        direction='increasing',
        interp_method='interp1d',
    )
    knee = locator.knee
    if knee is None:
        # No knee detected: fall back to two clusters.
        return 2
    return int(knee)

def mean_cluster_sim_to_all_clusters(sim_matrix: np.ndarray) -> list:
    """Average each cluster's similarity to every other cluster.

    Args:
        sim_matrix: Cluster-by-cluster similarity matrix.

    Returns:
        List with one value per row: the mean of that row over the first
        ``n_rows`` columns, excluding the diagonal entry.
    """
    cluster_ids = np.arange(sim_matrix.shape[0])
    return [
        # Keep every column except the cluster's own.
        np.mean(sim_matrix[i][cluster_ids[cluster_ids != i]])
        for i in range(sim_matrix.shape[0])
    ]

def select_ood_clusters(df, size_cutoff):
    """Greedily pick clusters while the running size stays below a cutoff.

    Clusters are visited in the row order of ``df``. A cluster is taken only
    if adding its ``size_full`` keeps the cumulative size strictly below
    ``size_cutoff``; otherwise it is skipped and the scan continues, so a
    later, smaller cluster may still be selected.

    Args:
        df: Table with ``'cluster'`` and ``'size_full'`` columns.
        size_cutoff: Exclusive upper bound on the total selected size.

    Returns:
        List of selected cluster ids, in visiting order.
    """
    picked = []
    running_total = 0
    cluster_ids = df['cluster']
    cluster_sizes = df['size_full']
    for cluster_id, cluster_size in zip(cluster_ids, cluster_sizes):
        if not (running_total + cluster_size < size_cutoff):
            # Would reach or exceed the cutoff: skip, but keep scanning.
            continue
        running_total += cluster_size
        picked.append(cluster_id)
    return picked

def split_data(args, clusters, ood_clusters):
    """Split item positions into train, eval and OOD index arrays.

    Items whose cluster is in ``ood_clusters`` form the OOD set; the rest
    are shuffled and divided by ``train_test_split`` using
    ``args.split_ratio`` as the eval fraction and ``args.seed`` as the
    random state.

    Args:
        args: Object with ``split_ratio`` and ``seed`` attributes.
        clusters: Cluster label of each item.
        ood_clusters: Collection of cluster labels treated as OOD.

    Returns:
        Tuple ``(train_indices, eval_indices, ood_indices)`` of NumPy arrays.
    """
    held_out = [
        position
        for position, label in enumerate(clusters)
        if label in ood_clusters
    ]
    in_distribution = [
        position
        for position, label in enumerate(clusters)
        if label not in ood_clusters
    ]
    train_part, eval_part = train_test_split(
        in_distribution,
        test_size=args.split_ratio,
        shuffle=True,
        random_state=args.seed,
    )
    return np.array(train_part), np.array(eval_part), np.array(held_out)


from rdkit.Chem import rdRascalMCES
class MCES:
    """Build train/eval/OOD splits from pairwise MCES similarity.

    A pairwise similarity matrix is computed with RDKit's RascalMCES and
    then divided into an in-distribution (ID) part and an OOD part, either
    by spectral clustering or by a greedy growth procedure. The ID part is
    further split into train and eval with ``train_test_split``.

    Args:
        args: Object providing ``split_ratio`` (eval fraction) and ``seed``
            (random state for the train/eval split).
        method: ``'spectral_clustering'`` or ``'greedy'``.
    """

    _METHODS = ('spectral_clustering', 'greedy')

    def __init__(self, args, method='spectral_clustering'):
        self.args = args
        self.method = method

    def get_mces_ood(self, mols_list):
        """Return ``(train_indices, eval_indices, ood_indices)`` arrays.

        Args:
            mols_list: Sequence of RDKit molecules.

        Raises:
            ValueError: If ``self.method`` is not a supported method.
        """
        # Validate before the expensive all-pairs similarity computation.
        if self.method not in self._METHODS:
            raise ValueError(
                f"Unknown split method {self.method!r}; expected one of {self._METHODS}"
            )

        sim_matrix = self.compute_similarity_matrix(mols_list)
        if self.method == 'spectral_clustering':
            train_eval_indices, ood_indices = self.spectral_clustering_split(sim_matrix)
        else:
            train_eval_indices, ood_indices = self.greedy_mces_split(sim_matrix)

        train_indices, eval_indices = train_test_split(
            train_eval_indices,
            test_size=self.args.split_ratio,
            shuffle=True,
            random_state=self.args.seed,
        )
        return np.array(train_indices), np.array(eval_indices), np.array(ood_indices)

    def calculate_mces_similarity(self, mol1, mol2):
        """Similarity from ``rdFMCS.FindMCS``: shared bonds / larger bond count.

        Returns 0.0 when the search is cancelled (2 second timeout) or the
        common substructure has no bonds.
        """
        mcs_result = rdFMCS.FindMCS(
            [mol1, mol2],
            completeRingsOnly=True,
            ringMatchesRingOnly=True,
            timeout=2,
        )
        if mcs_result.canceled:
            return 0.0
        num_bonds = mcs_result.numBonds
        if num_bonds == 0:
            return 0.0
        # num_bonds > 0 implies both molecules have bonds, so this is non-zero.
        max_bonds = max(mol1.GetNumBonds(), mol2.GetNumBonds())
        return num_bonds / max_bonds

    def score_difference(self, mol1, mol2):
        """Return the RascalMCES similarity of the first result for a pair.

        ``returnEmptyMCES`` is enabled on the options passed to
        ``FindMCES``. Should no result come back at all, the pair is scored
        0.0 instead of raising ``IndexError``.
        """
        opts = rdRascalMCES.RascalOptions()
        opts.returnEmptyMCES = True
        results = rdRascalMCES.FindMCES(mol1, mol2, opts)
        if not results:
            return 0.0
        return results[0].Similarity

    def compute_similarity_matrix(self, mols, n_jobs=60):
        """Compute the symmetric pairwise similarity matrix in parallel.

        Args:
            mols: Sequence of RDKit molecules.
            n_jobs: Number of joblib workers (default 60, the previously
                hard-coded value).

        Returns:
            ``(n, n)`` float array. Only off-diagonal entries are computed;
            the diagonal is left at 0.
        """
        # NOTE(review): self-similarity stays 0 on the diagonal — confirm
        # that downstream consumers do not expect 1.0 there.
        n = len(mols)

        def calculate_similarity(i, j):
            return i, j, self.score_difference(mols[i], mols[j])

        # Upper triangle only; the matrix is mirrored below.
        pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]

        results = Parallel(n_jobs=n_jobs)(
            delayed(calculate_similarity)(i, j)
            for i, j in tqdm(pairs, desc='calculating mces similarity...')
        )

        sim_matrix = np.zeros((n, n))
        for i, j, sim in results:
            sim_matrix[i, j] = sim
            sim_matrix[j, i] = sim
        return sim_matrix

    def greedy_mces_split(self, sim_matrix, id_ratio=0.95):
        """Grow an ID set greedily; whatever is left over becomes OOD.

        Starting from a random item (``random.choice``), the item with the
        highest mean similarity to the current ID set is added repeatedly
        until the ID set holds ``int(id_ratio * n)`` items (at least the
        start item, at most ``n``).

        Args:
            sim_matrix: ``(n, n)`` similarity matrix.
            id_ratio: Target fraction of items in the ID set.

        Returns:
            ``(id_indices, ood_indices)`` as sorted lists of ints.
        """
        sim_matrix = np.asarray(sim_matrix, dtype=float)
        n = sim_matrix.shape[0]
        if n == 0:
            return [], []
        # Never ask for more ID items than exist.
        target_id_size = min(int(id_ratio * n), n)

        in_id = np.zeros(n, dtype=bool)
        start = random.choice(range(n))
        in_id[start] = True
        n_id = 1

        # Running sum of each item's similarity to the ID members. All
        # candidates share the same divisor, so the arg-max of the sum is
        # the arg-max of the mean; this avoids recomputing every mean from
        # scratch on each step.
        sim_to_id = sim_matrix[:, start].copy()

        while n_id < target_id_size:
            # Mask out current members; ties go to the lowest index.
            scores = np.where(in_id, -np.inf, sim_to_id)
            best_candidate = int(np.argmax(scores))
            in_id[best_candidate] = True
            sim_to_id += sim_matrix[:, best_candidate]
            n_id += 1

        id_indices = np.flatnonzero(in_id).tolist()
        ood_indices = np.flatnonzero(~in_id).tolist()
        return id_indices, ood_indices

    def spectral_clustering_split(self, sim_matrix, n_clusters=10, id_ratio=0.95):
        """Cluster spectrally and hold out whole clusters as OOD.

        The similarity matrix is converted to an affinity
        ``exp(-(1 - sim))`` and clustered with ``SpectralClustering``
        (``random_state=42``). Cluster ids are shuffled with the global
        NumPy RNG; the first ``int(id_ratio * k)`` of the ``k`` clusters
        actually produced are ID and the remaining clusters are OOD.

        Args:
            sim_matrix: ``(n, n)`` similarity matrix.
            n_clusters: Number of clusters requested.
            id_ratio: Fraction of clusters (not items) assigned to ID.

        Returns:
            ``(id_indices, ood_indices)`` as lists of item positions.
        """
        distance_matrix = 1 - sim_matrix
        affinity_matrix = np.exp(-distance_matrix)

        clustering = SpectralClustering(
            n_clusters=n_clusters,
            affinity='precomputed',
            assign_labels='kmeans',
            random_state=42,
        )
        labels = clustering.fit_predict(affinity_matrix)

        clusters = {}
        for idx, label in enumerate(labels):
            clusters.setdefault(label, []).append(idx)

        cluster_ids = list(clusters.keys())
        np.random.shuffle(cluster_ids)

        # Use the number of clusters actually observed: if fewer distinct
        # labels come back than requested, sizing the ID share from
        # `n_clusters` could leave no cluster for the OOD set.
        n_id_clusters = int(id_ratio * len(cluster_ids))
        id_clusters = cluster_ids[:n_id_clusters]
        ood_clusters = cluster_ids[n_id_clusters:]

        id_indices = [idx for label in id_clusters for idx in clusters[label]]
        ood_indices = [idx for label in ood_clusters for idx in clusters[label]]

        return id_indices, ood_indices



def load_pretrained_gnn(path=None):
    """Build a GIN ``GNN_graphpred`` model and load pretrained weights.

    The model is created with 5 layers, embedding size 300, ``JK='last'``,
    dropout 0.5, mean pooling and a single task. Weights are loaded through
    the model's ``from_pretrained`` method when it has one, otherwise
    through ``load_state_dict`` (non-strict) after unwrapping nested
    checkpoints and stripping ``'module.'`` from parameter names.

    Args:
        path: Checkpoint path. ``None`` selects the bundled
            ``supervised_contextpred.pth`` checkpoint.

    Returns:
        The ``GNN_graphpred`` instance with weights loaded.

    Raises:
        ValueError: If the resolved checkpoint path is empty.
        Exception: Any error raised while loading the weights is printed
            and re-raised.
    """
    import sys
    # Make the sibling `baselines` package importable, without growing
    # sys.path on every call.
    if '..' not in sys.path:
        sys.path.append('..')
    from baselines.pretrained_gnns.model import GNN_graphpred

    gnn_num_layer = 5
    gnn_emb_dim = 300
    gnn_JK = "last"
    gnn_drop_ratio = 0.5
    gnn_graph_pooling = "mean"
    gnn_type = 'gin'
    ac_size = 1

    if path is None:
        pretrained_model_path = '../baselines/pretrained_gnns/model_gin/supervised_contextpred.pth'
    else:
        pretrained_model_path = path

    # Fail before constructing the model when there is nothing to load.
    if not pretrained_model_path:
        raise ValueError("encoder_path (pretrained_model_path) is required for multi_anchor_gnn")

    # Instantiate the base GNN encoder.
    pretrained_encoder = GNN_graphpred(
        num_layer=gnn_num_layer,
        emb_dim=gnn_emb_dim,
        num_tasks=ac_size,  # Use ac_size for num_tasks
        JK=gnn_JK,
        drop_ratio=gnn_drop_ratio,
        graph_pooling=gnn_graph_pooling,
        gnn_type=gnn_type,
    )

    # Load pretrained weights.
    print(f"------------------------Loading pretrained encoder from: {pretrained_model_path}------------------------")
    try:
        if hasattr(pretrained_encoder, 'from_pretrained'):
            # Prefer the model's own loader when it exists.
            pretrained_encoder.from_pretrained(pretrained_model_path)
            print("Loaded weights using 'from_pretrained' method.")
        else:
            # Fallback to loading the state_dict directly.
            print("No 'from_pretrained' method found, attempting load_state_dict.")
            # NOTE: torch.load unpickles the file; only load checkpoints
            # from trusted sources.
            state_dict = torch.load(pretrained_model_path, map_location='cpu')
            # Handle nested state dicts.
            if 'model_state_dict' in state_dict:
                state_dict = state_dict['model_state_dict']
            elif 'gnn_model_state_dict' in state_dict:
                state_dict = state_dict['gnn_model_state_dict']
            # Clean keys ('module.' prefix).
            state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
            # strict=False tolerates missing/unexpected keys.
            pretrained_encoder.load_state_dict(state_dict, strict=False)
            print("Loaded weights using 'load_state_dict'.")
    except Exception as load_e:
        print(f"ERROR loading pretrained weights into GNN_graphpred: {load_e}")
        print("Check path and model compatibility.")
        raise  # Re-raise error if loading fails
    return pretrained_encoder

def load_pretrained_smi_ted():
    """Load the pretrained SMI-TED light encoder.

    Calls ``load_smi_ted`` with the ``smi_ted_Light_40.pt`` checkpoint from
    ``../baselines/smi_ted/smi_ted_light``.

    Returns:
        The encoder object returned by ``load_smi_ted``.
    """
    import sys
    # Make the sibling `baselines` package importable, without growing
    # sys.path on every call.
    if '..' not in sys.path:
        sys.path.append('..')
    from baselines.smi_ted.smi_ted_light.load import load_smi_ted

    encoder_folder = "../baselines/smi_ted/smi_ted_light"
    encoder_ckpt = "smi_ted_Light_40.pt"

    pretrained_encoder = load_smi_ted(folder=encoder_folder, ckpt_filename=encoder_ckpt)
    return pretrained_encoder



def get_gnn_embedding(graph_list, device='cpu'):
    """Embed molecular graphs with the pretrained GNN encoder.

    All graphs are collated into a single batch, passed through the
    encoder's ``gnn`` module and pooled to one vector per graph with the
    encoder's ``pool`` function.

    Args:
        graph_list: List of PyTorch Geometric ``Data`` graphs.
        device: Device used for inference.

    Returns:
        NumPy array with one pooled embedding row per input graph.
    """
    pretrained_encoder = load_pretrained_gnn()
    pretrained_encoder.to(device)
    # The encoder is built with drop_ratio=0.5; without eval() dropout
    # stays active and the embeddings differ from call to call.
    pretrained_encoder.eval()

    # Use the returned object rather than relying on in-place movement.
    batch = Batch.from_data_list(graph_list).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        node_embs = pretrained_encoder.gnn(batch)
        gnn_embs = pretrained_encoder.pool(node_embs, batch.batch)

    return gnn_embs.detach().cpu().numpy()


def get_smi_ted_embedding(smiles_list, device='cpu'):
    """Embed SMILES strings with the pretrained SMI-TED encoder.

    Each SMILES is passed on its own to ``extract_embeddings`` with
    gradient tracking disabled, and the results are stacked vertically.

    Args:
        smiles_list: Iterable of SMILES strings.
        device: Device the encoder is moved to.

    Returns:
        NumPy array produced by ``np.vstack`` over the per-SMILES
        embeddings.
    """
    encoder = load_pretrained_smi_ted()
    encoder.to(device)

    per_molecule = []
    with torch.no_grad():
        for smiles in smiles_list:
            embedding = encoder.extract_embeddings(smiles)
            per_molecule.append(embedding.cpu().numpy())

    return np.vstack(per_molecule)