import os
import os.path as osp
import torch
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from tqdm import tqdm
import pickle
import rdkit
from rdkit.Chem import Descriptors
from rdkit.Chem import MolFromSmiles, MolToSmiles, rdmolops, QED
import networkx as nx
import sascorer

from utils.property_models import EGNN, Naive, NumNodes


def construct_complete_graph(num_node, return_index=True, add_self_loop=False):
    '''
    Build the dense adjacency matrix of a complete graph.

    num_node: number of nodes in the graph
    return_index: if True, additionally return the edge index obtained from
        ``torch_geometric.utils.dense_to_sparse``; if False, return only the
        dense adjacency matrix
    add_self_loop: whether add self loop in the graph

    Returns ``adj`` (a ``num_node x num_node`` tensor of 0/1 floats) when
    ``return_index`` is False, otherwise the tuple ``(adj, edge_index)``.
    '''
    # Ones everywhere except the diagonal.
    adj = 1 - torch.eye(num_node)
    if add_self_loop:
        adj += torch.eye(num_node)

    if not return_index:
        return adj

    # Imported lazily: only the sparse-index path needs torch_geometric, so
    # callers that just want the dense matrix work without it installed.
    import torch_geometric as tgeom

    edge_index, _ = tgeom.utils.dense_to_sparse(adj)
    return adj, edge_index
    

def get_model(args):
    '''
    Instantiate the property-prediction model selected by ``args.model_name``.

    args: namespace-like object. ``model_name`` and ``device`` are always
        read; for ``'egnn'`` the fields ``nf``, ``n_layers``, ``attention``
        and ``node_attr`` are read as well.

    Returns the constructed ``EGNN``, ``Naive`` or ``NumNodes`` instance.
    Raises ``Exception`` when ``args.model_name`` is none of ``'egnn'``,
    ``'naive'`` or ``'numnodes'``.
    '''
    requested = args.model_name

    # The two trivial baselines only need the target device.
    if requested == 'naive':
        return Naive(device=args.device)
    if requested == 'numnodes':
        return NumNodes(device=args.device)

    if requested != 'egnn':
        raise Exception("Wrong model name %s" % args.model_name)

    return EGNN(
        in_node_nf=5,
        in_edge_nf=0,
        hidden_nf=args.nf,
        device=args.device,
        n_layers=args.n_layers,
        coords_weight=1.0,
        attention=args.attention,
        node_attr=args.node_attr,
    )


model = None
# Condition whose classifier evaluate() last loaded into ``model``; stays None
# until evaluate() has loaded a classifier itself.
_model_condition = None


def evaluate(mol_list, condition_list, condition = 'alpha', return_pred = False):
    '''
    Score molecules with a pretrained property classifier and report the MAE
    against the requested condition values.

    mol_list: iterable of RDKit molecules
    condition_list: target property value for each molecule, paired with
        ``mol_list`` by position
    condition: property name; must be a key of ``cond_value_dict`` below
        ('mu', 'alpha', 'gap', 'homo', 'lumo', 'cv')
    return_pred: selects the second element of the returned tuple

    Molecules are skipped (contributing neither a prediction nor a label)
    when they have more than 29 atoms, contain an element outside
    H/C/N/O/F, or when RDKit fails to embed a conformer. The returned
    prediction list can therefore be shorter than ``mol_list``.

    Returns ``(mae, pred_list)`` if ``return_pred`` is True, otherwise
    ``(mae, no_emb_cnt)`` where ``no_emb_cnt`` counts embedding failures.
    If every molecule is skipped the mean is taken over an empty array.

    The classifier is cached in the module-level ``model``. It is loaded from
    a path relative to the current working directory and reloaded whenever a
    later call asks for a different ``condition`` than the cached one.
    '''
    global model, _model_condition
    qm9_atom_list = ['H', 'C', 'O', 'N', 'F']
    atom_encoder = {'H': 0, 'C': 1, 'N': 2, 'O': 3, 'F': 4}

    def get_classifier(dir_path='', device='cpu'):
        # Rebuild the classifier from its pickled args and saved weights.
        print('Loading classifier from %s' % dir_path)
        # NOTE: pickle.load (and torch.load below) can execute arbitrary code;
        # only point this at trusted checkpoint directories.
        with open(osp.join(dir_path, 'args.pickle'), 'rb') as f:
            args_classifier = pickle.load(f)
        args_classifier.device = device
        args_classifier.model_name = 'egnn'
        classifier = get_model(args_classifier)
        classifier_state_dict = torch.load(osp.join(dir_path, 'best_checkpoint.npy'), map_location=torch.device('cpu'))
        classifier.load_state_dict(classifier_state_dict)
        return classifier

    # Load when nothing is cached, or when the cache was filled by an earlier
    # call for a different condition (previously the stale classifier was
    # silently reused). A model assigned to ``model`` from outside, while
    # ``_model_condition`` is still None, is left untouched.
    cached_for_other = _model_condition is not None and _model_condition != condition
    if model is None or cached_for_other:
        model = get_classifier('../LDM_property_prediction/checkpoints/QM9/Property_Classifiers/exp_class_' + condition)
        _model_condition = condition

    # Per-condition constants, unpacked below as
    # [cond_max, cond_min, cond_mean, cond_mad, pred_max, pred_min].
    cond_value_dict = {"mu": [1.9160064458847046, -0.22752515971660614, 2.6750875, 11.757326126098633, 19.160064322765635, -2.2752515523038177],
                       "alpha": [0.9102288484573364, -1.0507861375808716, 75.37342, 62.727723121643066, 9.102288057650055, -10.507861181150194],
                       "gap": [0.947828471660614, -0.5386840105056763, 0.25221726, 0.390242263674736, 9.478284826116898, -5.386839999595246],
                       "homo": [0.7526819705963135, -1.1657202243804932, -0.24028876, 0.1615406945347786, 7.526819599545699, -11.657201893308272],
                       "lumo": [0.47794172167778015, -0.4841439723968506, 0.011928127, 0.37990380078554153, 4.77941704545956, -4.841439664678198],
                       "cv": [0.4776511788368225, -0.7890406250953674, 31.620028, 32.11751937866211, 4.776511788526875, -7.890406281196496]}
    cond_max, cond_min, cond_mean, cond_mad, pred_max, pred_min = cond_value_dict[condition]

    pred_list = []
    label_list = []
    no_emb_cnt = 0
    for mol, cond in zip(tqdm(mol_list), condition_list):
        # filter out molecules with atoms not in qm9
        atom_list = [atom.GetSymbol() for atom in mol.GetAtoms()]
        if len(atom_list) > 29:
            continue
        if len(set(atom_list).difference(qm9_atom_list)) > 0:
            continue

        # One-hot element features, one row per atom in atom_list.
        nodes = torch.cat([ torch.nn.functional.one_hot(torch.tensor(atom_encoder[atom],
            dtype=torch.int64), num_classes=5).unsqueeze(dim=0) for atom in atom_list ], dim=0).float()

        # NOTE(review): atom_list was collected before AddHs, so if the input
        # lacks explicit hydrogens, ``nodes`` has fewer rows than
        # ``atom_positions`` - confirm inputs carry explicit Hs.
        mol = Chem.AddHs(mol)
        emb_res = AllChem.EmbedMolecule(mol, maxAttempts = 10000)
        if emb_res == -1:
            # No conformer could be embedded; count it and skip the molecule.
            no_emb_cnt += 1
            continue
        atom_positions = torch.tensor(mol.GetConformer().GetPositions(), dtype=torch.float32)

        _, edge_index = construct_complete_graph(len(atom_list), return_index=True, add_self_loop=False)
        edges = [edge_index[0], edge_index[1]]

        # Single unbatched molecule: every node and edge is kept.
        atom_mask = torch.ones((len(atom_list), 1))
        edge_mask = torch.ones((edge_index.shape[1], 1))

        n_nodes = len(atom_list)

        with torch.no_grad():
            pred = model(h0=nodes, x=atom_positions, edges=edges, edge_attr=None, node_mask=atom_mask, edge_mask=edge_mask, n_nodes=n_nodes)
        # Clamp the raw prediction to the per-condition [pred_min, pred_max].
        pred[pred>pred_max] = pred_max
        pred[pred<pred_min] = pred_min

        # Map the clamped output back to property units using the stored
        # mean / MAD (the factor 10 presumably mirrors the classifier's
        # training-time scaling - TODO confirm).
        pred_list.append(pred.item() * cond_mad / 10 + cond_mean)
        label_list.append(cond)

    mae = np.abs((np.array(pred_list) - np.array(label_list))).mean()
    if return_pred:
        return mae, pred_list
    return mae, no_emb_cnt


def calculate_qed(smiles):
    '''
    Return ``QED.default`` of the molecule parsed from ``smiles``.

    smiles: SMILES string of the molecule

    Raises ``AssertionError`` when RDKit cannot parse ``smiles``.
    '''
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; the exception type callers see is unchanged.
        raise AssertionError("Invalid SMILES string: %r" % (smiles,))
    return QED.default(mol)

def calculate_plogp(smiles):
    '''
    Compute a normalized penalized-logP style score for ``smiles``.

    smiles: SMILES string of the molecule

    The result is the sum of three standardized terms:
      * ``Descriptors.MolLogP`` of the molecule,
      * the negated ``sascorer.calculateScore``,
      * the negated number of atoms by which the largest cycle in the
        networkx cycle basis exceeds six (zero when no cycle is that large).
    Each term is shifted and scaled by the hard-coded mean / std constants
    below (presumably dataset statistics - TODO confirm their origin).

    Returns the summed score as a numpy scalar.
    '''
    SA_scores_mean = -3.0532705350334903
    SA_scores_std = 0.8348193592477144
    logP_values_mean = 2.4571235825155546
    logP_values_std = 1.4343304662156684
    cycle_scores_mean = -0.0483411904303765
    cycle_scores_std = 0.28760425031637704

    # Round-trip through RDKit's SMILES writer as before, but parse the
    # resulting string once and reuse the molecule for every descriptor.
    smiles_rdkit = MolToSmiles(MolFromSmiles(smiles), isomericSmiles = True)
    mol = MolFromSmiles(smiles_rdkit)

    logP_values = Descriptors.MolLogP(mol)
    SA_scores = -sascorer.calculateScore(mol)

    cycle_list = nx.cycle_basis(nx.Graph(rdmolops.GetAdjacencyMatrix(mol)))
    largest_cycle = max(map(len, cycle_list), default=0)
    # Only atoms beyond a six-membered ring are penalized.
    cycle_scores = -max(largest_cycle - 6, 0)

    SA_scores_normalized = (np.array(SA_scores) - SA_scores_mean) / SA_scores_std
    logP_values_normalized = (np.array(logP_values) - logP_values_mean) / logP_values_std
    cycle_scores_normalized = (np.array(cycle_scores) - cycle_scores_mean) / cycle_scores_std

    return SA_scores_normalized + logP_values_normalized + cycle_scores_normalized

def evaluate_zinc(smile_list, y_list, condition = "plogp"):
    '''
    Compare computed molecular scores against requested target values.

    smile_list: sequence of SMILES strings
    y_list: requested property value for each SMILES, aligned by position
    condition: ``"plogp"`` to score with ``calculate_plogp`` or ``"qed"`` to
        score with ``calculate_qed``

    Returns ``(mae, top3)`` where ``mae`` is the mean absolute difference
    between the computed scores and ``y_list``, and ``top3`` is a dict with
    keys ``"target"`` and ``"smiles"`` holding the (up to) three highest
    computed scores and their SMILES, best first.

    Raises ``ValueError`` for an unsupported ``condition``.
    '''
    if condition == "plogp":
        score_fn = calculate_plogp
    elif condition == "qed":
        score_fn = calculate_qed
    else:
        # Previously this fell through and failed later with an unbound
        # ``targets`` variable.
        raise ValueError("Unknown condition %r; expected 'plogp' or 'qed'" % (condition,))

    targets = np.array([score_fn(smile) for smile in smile_list])

    mae = np.abs((targets - np.array(y_list))).mean()
    # Indices of the three largest scores, in descending score order.
    top3_idx = np.argsort(targets)[::-1][:3]
    top_3_target = [targets[i] for i in top3_idx]
    top_3_smiles = [smile_list[i] for i in top3_idx]
    return mae, {"target": top_3_target, "smiles": top_3_smiles}

