import deepchem as dc
from rdkit import Chem
from rdkit.Chem import AllChem, rdDepictor
import numpy as np
import pandas as pd
from model import *
from utils import *
from tqdm import tqdm
from torch.utils.data import DataLoader
import pickle
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdmolops
from chiralfinder import ChiralFinder
import random
from collections import defaultdict
import os

def smiles_to_3d(smiles):
    """Build an RDKit Mol with coordinates from a SMILES string.

    A 3D conformer is embedded with ETKDGv3 and then relaxed with UFF. If
    either step fails, the explicit hydrogens are removed again and planar
    2D coordinates are computed instead.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        A ``(mol, tag)`` tuple. ``tag`` is ``"3d"`` when the embedded and
        optimized conformer is kept (``mol`` then still carries the explicit
        hydrogens added here) and ``"2d"`` when the depiction fallback was
        used (explicit hydrogens removed).

    Raises:
        AssertionError: if the SMILES cannot be parsed, or if no conformer
            exists after both attempts.
    """
    mol = Chem.MolFromSmiles(smiles)
    assert mol is not None, f"Invalid SMILES: {smiles}"
    mol = Chem.AddHs(mol)
    tag = "3d"
    try:
        # Generate a 3D conformer.
        options = AllChem.ETKDGv3()
        options.timeout = 60
        # EmbedMolecule reports failure through a negative return value
        # rather than an exception, so check it explicitly instead of
        # relying on the optimizer to trip over the missing conformer.
        conf_id = AllChem.EmbedMolecule(mol, options)
        if conf_id < 0:
            raise RuntimeError("conformer embedding failed")
        AllChem.UFFOptimizeMolecule(mol)
    except Exception as e:
        print(f"3D 构建失败: {smiles[:20]}...  {e}, turn to 2D")
        # Fall back to planar coordinates on the heavy-atom graph.
        mol = Chem.RemoveHs(mol)
        rdDepictor.Compute2DCoords(mol)
        tag = "2d"
    assert mol.GetNumConformers() > 0, "No conformer found!"
    return mol, tag

def run_epoch(model, loader, device, loss_fn, optimizer=None,
              use_orth_reg=False, reg_lambda=1.0, mode="train", save_path=None, dataset_name=None):
    """Run one full pass over ``loader`` in train, val or test mode.

    Args:
        model: network called as ``model(feats_q, feats_q_kv, feats_k, k_types,
            edge_types_qk, coords_q, coords_k, q_mask, k_mask)`` and returning a
            3-tuple whose first item is the prediction and whose last item is
            the orthogonality regularization loss.
        loader: iterable of batch dicts with the keys read below.
        device: device the batch tensors are moved to.
        loss_fn: element-wise loss (``reduction='none'``); it is multiplied by
            the per-label weights before averaging.
        optimizer: required when ``mode == "train"``, ignored otherwise.
        use_orth_reg: add ``reg_lambda * loss_orth_reg`` to the loss.
        reg_lambda: weight of the orthogonality regularizer.
        mode: ``"train"``, ``"val"`` or ``"test"``.
        save_path: unused; kept for interface compatibility.
        dataset_name: ``"freesolv"`` selects the regression path; every other
            value is treated as (multi-task) binary classification.

    Returns:
        ``(avg_loss, avg_acc, avg_rocauc, test_res)``.
        ``avg_loss`` is the mean loss per label element. For classification,
        ``avg_acc`` and ``avg_rocauc`` come from ``compute_accuracy_multitasks``
        and ``compute_multitask_metrics``. For ``"freesolv"``, ``avg_acc`` holds
        ``sqrt(avg_loss)`` (note the loss includes weights and, if enabled, the
        regularizer) and ``avg_rocauc`` is ``0.0``. ``test_res`` is
        ``[labels, preds]`` as nested lists in test mode and ``None`` otherwise.

    Raises:
        ValueError: if ``mode == "train"`` and no optimizer is given.
    """
    is_train = mode == "train"
    is_regression = dataset_name in ["freesolv"]

    if is_train:
        if optimizer is None:
            raise ValueError("run_epoch(mode='train') requires an optimizer")
        model.train()
        # Debugging aid that makes backward raise on NaNs. It is a global,
        # sticky switch and slows training noticeably, so set it once per
        # epoch instead of once per batch.
        torch.autograd.set_detect_anomaly(True)
    else:
        model.eval()

    total_loss, total_acc, count = 0.0, 0.0, 0
    test_res = [[], []] if mode == "test" else None
    test_preds, test_labels = [], []

    for batch in tqdm(loader, desc=f"{mode.capitalize()} Batches", leave=False):
        feats_q = batch['feats_q'].to(device)
        feats_q_kv = batch['feats_q_kv'].to(device)
        feats_k = batch['feats_k'].to(device)
        coords_q = batch['coords_q'].to(device)
        coords_k = batch['coords_k'].to(device)
        q_mask = batch['q_mask'].to(device)
        k_mask = batch['k_mask'].to(device)
        k_types = batch['k_atom_types'].to(device)
        edge_types_qk = batch['edge_types_qk'].to(device)
        labels = batch['labels'].to(device)
        weights = batch['weights'].to(device)

        if is_train:
            optimizer.zero_grad()

        # Only track gradients when they are actually needed; evaluation
        # passes otherwise keep a full autograd graph alive for nothing.
        with torch.set_grad_enabled(is_train):
            output, _, loss_orth_reg = model(feats_q, feats_q_kv, feats_k, k_types, edge_types_qk,
                                             coords_q, coords_k, q_mask, k_mask)
            loss = loss_fn(output, labels.float())
            loss = torch.mean(loss * weights)
            if use_orth_reg:
                loss = loss + reg_lambda * loss_orth_reg

        if is_train:
            loss.backward()
            for name, p in model.named_parameters():
                # Parameters that did not take part in the loss have no grad.
                if p.grad is not None and not torch.isfinite(p.grad).all():
                    print("NaN in grad:", name)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()

        # ``loss`` is a mean over every label element, so weight it by the
        # element count to stay consistent with ``count`` below.
        n_elements = labels.numel()
        total_loss += loss.item() * n_elements
        # Accuracy is only defined for the (multi-label) classification sets.
        if not is_regression:
            total_acc += compute_accuracy_multitasks(output, labels) * n_elements
        count += n_elements

        test_preds.append(output.detach().cpu())
        test_labels.append(labels.cpu())
        if mode == "test":
            preds = output if is_regression else (torch.sigmoid(output) > 0.5).float()
            test_res[0].extend(labels.cpu().numpy().tolist())
            test_res[1].extend(preds.detach().cpu().numpy().tolist())

    avg_loss = total_loss / count
    avg_acc = total_acc / count
    test_preds = torch.cat(test_preds, dim=0)
    test_labels = torch.cat(test_labels, dim=0)
    avg_rocauc = 0.
    if is_regression:
        # Regression reports the root of the mean loss in the accuracy slot.
        avg_acc = np.sqrt(avg_loss)
    else:
        avg_rocauc, auc_list = compute_multitask_metrics(test_preds, test_labels)
    return avg_loss, avg_acc, avg_rocauc, test_res


def train_hct(train_data, val_data, test_data, args):
    """Train an HCTModel and record the test results of the best-validation epoch.

    Every epoch runs a train, a validation and a test pass. Whenever the
    validation loss improves, the metrics of that epoch together with the test
    labels/predictions are pickled to the result file (overwriting the
    previous best).

    Args:
        train_data, val_data, test_data: datasets consumed by ``DataLoader``
            with ``collate_hct``.
        args: run configuration; the attributes read are ``bs``,
            ``hidden_dim``, ``num_heads``, ``num_layers``, ``proj_dim``,
            ``chiral_encoder``, ``use_qr``, ``num_classes``, ``device``, ``lr``,
            ``min_lr``, ``weight_decay``, ``dataset``, ``epochs``,
            ``use_orth_reg`` and ``reg_lambda``.

    Returns:
        The model as it is after the final epoch (not the best-validation
        snapshot).
    """
    # Dataset & Dataloader
    train_loader = DataLoader(train_data, batch_size=args.bs, shuffle=True, collate_fn=collate_hct)
    val_loader = DataLoader(val_data, batch_size=args.bs, shuffle=False, collate_fn=collate_hct)
    test_loader = DataLoader(test_data, batch_size=args.bs, shuffle=False, collate_fn=collate_hct)

    model = HCTModel(d_model=args.hidden_dim, n_heads=args.num_heads, num_layers=args.num_layers, proj_dim=args.proj_dim,
                     chiral_encoder=args.chiral_encoder, use_qr=args.use_qr, ecd=False, num_classes=args.num_classes).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    if args.dataset in ["freesolv"]:
        loss_fn = torch.nn.MSELoss(reduction='none')
    else:
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='none')

    # The result file name depends only on the run configuration.
    prefix = f"epochs{args.epochs}_{args.bs}_{args.lr}_{args.min_lr}_{args.weight_decay}_qr_{args.hidden_dim}_{args.num_layers}_{args.proj_dim}_{args.chiral_encoder}"
    res_path = f"/home/data/HCT/res/moleculenet/test_res/hct_{args.dataset}_{prefix}_test_res.pkl"

    shared_kwargs = dict(use_orth_reg=args.use_orth_reg, reg_lambda=args.reg_lambda, dataset_name=args.dataset)

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        train_loss, train_acc, train_rocauc, _ = run_epoch(
            model, train_loader, args.device, loss_fn, optimizer, mode="train", **shared_kwargs)
        val_loss, val_acc, val_rocauc, _ = run_epoch(
            model, val_loader, args.device, loss_fn, mode="val", **shared_kwargs)
        test_loss, test_acc, test_rocauc, test_res = run_epoch(
            model, test_loader, args.device, loss_fn, mode="test", **shared_kwargs)

        # Model selection is by validation loss; keep the matching test results.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Create the output directory on demand so a missing folder does
            # not abort the run after an epoch has already been trained.
            os.makedirs(os.path.dirname(res_path), exist_ok=True)
            with open(res_path, "wb") as f:
                pickle.dump({
                    "epoch": epoch+1,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "train_rocauc": train_rocauc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "val_rocauc": val_rocauc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_rocauc": test_rocauc,
                    "test_labels": test_res[0],
                    "test_preds": test_res[1]
                }, f)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f}, ROC-AUC: {train_rocauc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}, ROC-AUC: {val_rocauc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Acc: {test_acc:.4f}, ROC-AUC: {test_rocauc:.4f}")

    return model


if __name__ == "__main__":
    # Script entry point: everything below runs only when the file is executed directly.
    # get_args is not defined in this file; presumably it comes from one of the star imports.
    args = get_args()

    def process_data(input_mols, labels, weights, tag, dataset_name):
        """Turn molecules with coordinates into per-molecule feature dicts.

        Args:
            input_mols: RDKit Mol objects, each carrying at least one conformer.
            labels: per-molecule labels, indexed in step with ``input_mols``.
            weights: per-molecule label weights, indexed in step with ``input_mols``.
            tag: split name; part of the ChiralFinder cache file name and of the log lines.
            dataset_name: dataset name; part of the ChiralFinder cache file name.

        Returns:
            A list with one dict per molecule holding the keys 'coords',
            'atom_types', 'edge_types', 'atom_onehot', 'atom_chiral',
            'rdkit_mol', 'label' and 'weight'.
        """
        # Local import: this file has no explicit top-level ``import sys``.
        import sys

        show_progress = sys.stdout.isatty()
        failed_mols = set()
        none_chiral_idx = []
        data_list = []

        # ChiralFinder runs on hydrogen-complete copies; featurization below
        # keeps using the molecules exactly as they were passed in.
        mols = [
            Chem.AddHs(mol, addCoords=True)
            for mol in tqdm(input_mols, total=len(input_mols), desc="processing chiral", disable=not show_progress)
        ]
        batch_size = 1000

        # Reuse the cached ChiralFinder output when it exists.
        pth = f"/home/data/HCT/data/moleculenet/{dataset_name}_{tag}_chiralfinder.pkl"
        if os.path.exists(pth):
            # NOTE: unpickling executes arbitrary code; only load caches written by this script.
            with open(pth, "rb") as f:
                res = pickle.load(f)
        else:
            res = []
            for start in tqdm(range(0, len(mols), batch_size), desc="batch", disable=not show_progress):
                batch_mols = mols[start:start + batch_size]
                cf_batch = ChiralFinder(batch_mols, "molecules")
                res.extend(cf_batch.get_central(n_cpus=4))
            with open(pth, "wb") as f:
                pickle.dump(res, f)

        # The cache is keyed only by dataset name and split, so a stale file
        # from a different split would be picked up silently.
        if len(res) != len(input_mols):
            print(f"Warning: {len(res)} ChiralFinder results for {len(input_mols)} molecules ({pth}); cache may be stale.")

        for idx, mol in enumerate(tqdm(input_mols, total=len(input_mols), desc="processing data", disable=not show_progress)):
            label = labels[idx]
            weight = weights[idx]
            n_atoms = mol.GetNumAtoms()

            # 1. coordinates of the default conformer
            conf = mol.GetConformer()
            coords = np.array([list(conf.GetAtomPosition(i)) for i in range(n_atoms)], dtype=np.float32)

            # 2. atom_types: 0=chiral_center, 1=chiral_related, 2=non_related
            atom_types = np.full(n_atoms, 2, dtype=np.int64)
            # in rare cases, the idx is different with mol_without_Hs
            rdkit_chiral_res = Chem.FindMolChiralCenters(mol, useLegacyImplementation=False)
            if not rdkit_chiral_res:
                none_chiral_idx.append(idx)
            chiral_idx = [i for i, _ in rdkit_chiral_res]
            atom_types[chiral_idx] = 0

            # 3. chiral_related: direct neighbors of a chiral center that are
            #    not chiral centers themselves
            adj = rdmolops.GetAdjacencyMatrix(mol)
            if chiral_idx:
                next_to_center = (adj[chiral_idx] > 0).any(axis=0)
                atom_types[next_to_center & (atom_types == 2)] = 1

            # 4. atom_'onehot': 52, like ChIRo
            atom_onehot = getNodeFeatures(mol.GetAtoms(), mol, False)

            # 5. atom_chiral: 9-dim vector for chiral centers, zeros for others
            atom_chiral = np.zeros((n_atoms, 9), dtype=np.float32)
            for i, center in enumerate(res[idx]["center id"]):
                try:
                    atom_chiral[center] = np.array(res[idx]["quadrupole matrix"][i][0].reshape(9), dtype=np.float32)
                except Exception:
                    # Best effort: leave this center at zero and count the
                    # molecule once, however many of its centers fail.
                    failed_mols.add(idx)

            # 6. edge_types: (n_atoms, n_atoms); 1 for center<->related pairs,
            #    2 for center<->non_related pairs, 0 everywhere else. The
            #    diagonal stays 0 because an atom has exactly one type.
            is_center = atom_types == 0
            is_related = atom_types == 1
            is_other = atom_types == 2
            edge_types = np.zeros((n_atoms, n_atoms), dtype=np.int64)
            center_related = is_center[:, None] & is_related[None, :]
            center_other = is_center[:, None] & is_other[None, :]
            edge_types[center_related | center_related.T] = 1
            edge_types[center_other | center_other.T] = 2

            data_list.append({
                'coords': coords,
                'atom_types': atom_types,
                'edge_types': edge_types,
                'atom_onehot': atom_onehot,
                'atom_chiral': atom_chiral,
                'rdkit_mol': mol,
                'label': label,
                'weight': weight
            })
        print(f"{tag} set: total {len(input_mols)} molecules, {len(failed_mols)} molecules with failed chiral matrix.")
        print(f"{tag} set: {len(none_chiral_idx)} molecules with no chiral centers.")
        return data_list

    # bbbp, sider, clintox, bace, Freesolv
    def run_dataset(dataset_name):
        """Load a MoleculeNet dataset, attach coordinates and featurize all splits.

        Molecules with coordinates are cached in
        ``{dataset_name}_with_3d.pkl`` so conformer generation only runs once.

        Args:
            dataset_name: one of ``"bbbp"``, ``"sider"``, ``"clintox"``,
                ``"freesolv"`` or ``"bace"``.

        Returns:
            ``(train_data, val_data, test_data, num_tasks)`` where the first
            three are the lists produced by ``process_data``.

        Raises:
            ValueError: for an unknown ``dataset_name``.
        """
        data_root = "/home/data/HCT/data/moleculenet"
        # Loader plus dataset-specific extra keyword arguments.
        loaders = {
            "bbbp": (dc.molnet.load_bbbp, {}),
            "sider": (dc.molnet.load_sider, {}),
            "clintox": (dc.molnet.load_clintox, {}),
            # An empty ``transformers`` value presumably turns off the loader's
            # default target transformation — TODO confirm against DeepChem.
            "freesolv": (dc.molnet.load_sampl, {"transformers": ""}),
            "bace": (dc.molnet.load_bace_classification, {}),
        }
        if dataset_name not in loaders:
            raise ValueError("Invalid dataset name!")
        load_fn, extra_kwargs = loaders[dataset_name]
        tasks, datasets, _transformers = load_fn(featurizer='Raw', splitter='scaffold', reload=True,
                                                 data_dir=data_root, save_dir=data_root, **extra_kwargs)

        train, valid, test = datasets
        num_tasks = len(tasks)

        def process_dataset(dataset, tasks_num):
            """Return ``(mols, labels, weights)`` for one DeepChem split."""
            df = dataset.to_dataframe()
            print(df.head())

            smiles_list = df["ids"].tolist()
            if tasks_num == 1:
                labels = df["y"].tolist()
                weights = df["w"].tolist()
            else:
                # Multi-task frames use the columns y1..yN and w1..wN.
                labels = df[[f"y{i+1}" for i in range(tasks_num)]].values.tolist()
                weights = df[[f"w{i+1}" for i in range(tasks_num)]].values.tolist()

            mols = []
            tags_ = []
            for smi in tqdm(smiles_list):
                mol, tag = smiles_to_3d(smi)
                tags_.append(tag)
                mols.append(mol)
            print(f"Dataset {dataset_name}, total {len(mols)} molecules, {tags_.count('2d')} molecules with 2D coordinates.")
            return mols, labels, weights

        pth = f"/home/data/HCT/data/moleculenet/{dataset_name}_with_3d.pkl"
        if os.path.exists(pth):
            # NOTE: unpickling executes arbitrary code; only load caches written by this script.
            with open(pth, "rb") as f:
                data = pickle.load(f)
            train_mols, train_labels, train_weights = data["train"]
            valid_mols, valid_labels, valid_weights = data["val"]
            test_mols, test_labels, test_weights = data["test"]
        else:
            train_mols, train_labels, train_weights = process_dataset(train, num_tasks)
            valid_mols, valid_labels, valid_weights = process_dataset(valid, num_tasks)
            test_mols, test_labels, test_weights = process_dataset(test, num_tasks)
            # Store all three splits together in one cache file.
            with open(pth, "wb") as f:
                pickle.dump({
                    "train": (train_mols, train_labels, train_weights),
                    "val": (valid_mols, valid_labels, valid_weights),
                    "test": (test_mols, test_labels, test_weights)
                }, f)

        train_data = process_data(train_mols, train_labels, train_weights, tag="train", dataset_name=dataset_name)
        val_data = process_data(valid_mols, valid_labels, valid_weights, tag="val", dataset_name=dataset_name)
        test_data = process_data(test_mols, test_labels, test_weights, tag="test", dataset_name=dataset_name)

        return train_data, val_data, test_data, num_tasks

    # if too slow for some mol, add timeout, to 2d
    train_data, val_data, test_data, tasks_num = run_dataset(args.dataset)
    # Supported dataset names: bbbp, sider, clintox, freesolv, bace.

    # One output per task (a 0/1 label for the classification sets).
    args.num_classes = tasks_num
    print("number of tasks: ", tasks_num)
    model = train_hct(train_data, val_data, test_data, args)

    # NOTE(review): this path lives under /amax while every other path in the
    # script is under /home — confirm that is intended.
    ckpt_path = f"/amax/data/HCT/res/moleculenet/ckpt_epoch_0_{args.dataset}.pth"
    # Make sure the target folder exists so the trained weights are not lost
    # to a missing directory at the very end of the run.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)
    checkpoint = {
        "model_state_dict": model.state_dict(),
    }
    torch.save(checkpoint, ckpt_path)
