from model import *
from utils import *
from tqdm import tqdm
from torch.utils.data import DataLoader
import pickle
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdmolops
from chiralfinder import ChiralFinder
import random
from collections import defaultdict
import os


def run_epoch(model, loader, device, loss_fn, optimizer=None,
              use_orth_reg=False, reg_lambda=1.0, mode="train", save_path=None):
    """Run one pass over ``loader`` and return sample-weighted metrics.

    Args:
        model: Module called with the batch tensors; it must return a
            3-tuple whose first element is the prediction passed to
            ``loss_fn`` and whose third element is the orthogonality
            regularisation loss.
        loader: Iterable of dict batches with the keys read below
            ('feats_q', 'feats_q_kv', 'feats_k', 'coords_q', 'coords_k',
            'q_mask', 'k_mask', 'k_atom_types', 'edge_types_qk', 'labels').
        device: Device every batch tensor is moved to.
        loss_fn: Callable ``loss_fn(output, labels.long())``.
        optimizer: Optimizer stepped once per batch; required when
            ``mode == "train"`` and ignored otherwise.
        use_orth_reg: If true, add ``reg_lambda * loss_orth_reg`` to the loss.
        reg_lambda: Weight of the orthogonality regularisation term.
        mode: "train" enables training mode and parameter updates; any other
            value puts the model in eval mode. "test" additionally collects
            labels and argmax predictions.
        save_path: Currently unused; kept so existing callers keep working.

    Returns:
        ``(avg_loss, avg_acc, test_res)`` where the averages are weighted by
        batch size and ``test_res`` is ``[labels, preds]`` (two flat lists)
        in "test" mode, otherwise ``None``.

    Raises:
        ValueError: If ``mode == "train"`` and no optimizer was given.
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    is_train = mode == "train"
    if is_train and optimizer is None:
        # Fail before touching the model instead of on the first batch.
        raise ValueError('run_epoch(mode="train") requires an optimizer')

    if is_train:
        model.train()
    else:
        model.eval()
    total_loss, total_acc, count = 0.0, 0.0, 0
    test_res = [[], []] if mode == "test" else None

    # Autograd is only needed when we back-propagate; disabling it for
    # val/test avoids building (and holding) a graph for every batch.
    with torch.set_grad_enabled(is_train):
        for batch in tqdm(loader, desc=f"{mode.capitalize()} Batches", leave=False):
            feats_q = batch['feats_q'].to(device)
            feats_q_kv = batch['feats_q_kv'].to(device)
            feats_k = batch['feats_k'].to(device)
            coords_q = batch['coords_q'].to(device)
            coords_k = batch['coords_k'].to(device)
            q_mask = batch['q_mask'].to(device)
            k_mask = batch['k_mask'].to(device)
            k_types = batch['k_atom_types'].to(device)
            edge_types_qk = batch['edge_types_qk'].to(device)
            labels = batch['labels'].to(device)

            if is_train:
                optimizer.zero_grad()

            output, _, loss_orth_reg = model(feats_q, feats_q_kv, feats_k, k_types, edge_types_qk,
                                             coords_q, coords_k, q_mask, k_mask)
            loss = loss_fn(output, labels.long())
            if use_orth_reg:
                loss = loss + reg_lambda * loss_orth_reg
            if is_train:
                loss.backward()
                # Clip the global gradient norm before the update.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
                optimizer.step()

            # Weight by batch size so a smaller final batch does not skew
            # the epoch averages.
            batch_n = len(labels)
            total_loss += loss.item() * batch_n
            total_acc += compute_accuracy_multiclasses(output, labels) * batch_n
            count += batch_n

            if mode == "test":
                preds = torch.argmax(output, dim=1)
                test_res[0].extend(labels.cpu().numpy().tolist())
                test_res[1].extend(preds.detach().cpu().numpy().tolist())

    avg_loss = total_loss / count
    avg_acc = total_acc / count

    return avg_loss, avg_acc, test_res


def train_hct(train_data, val_data, test_data, args,
              res_path="/home/data/HCT/res/hct_RS_test_res.pkl"):
    """Train an ``HCTModel`` and record test results at the best validation loss.

    Every epoch runs a train, a validation and a test pass. Whenever the
    validation loss improves, the metrics of that epoch together with the
    test labels/predictions are pickled to ``res_path`` (overwriting the
    previous best).

    Args:
        train_data, val_data, test_data: Datasets fed to ``DataLoader`` with
            ``collate_hct`` as the collate function.
        args: Namespace providing ``device``, ``epochs``, ``bs``, ``lr``,
            ``weight_decay``, ``use_qr``, ``reg_lambda``, ``use_orth_reg``,
            ``hidden_dim``, ``num_heads``, ``num_layers``, ``proj_dim``,
            ``chiral_encoder`` and ``num_classes``.
        res_path: Where the best-epoch results are pickled. Defaults to the
            location that was previously hard-coded.

    Returns:
        The model as it is after the LAST epoch.
        NOTE(review): these are not the weights of the best-validation
        epoch whose results are pickled — confirm this is intended.
    """
    device = args.device
    epochs = args.epochs
    batch_size = args.bs
    reg_lambda = args.reg_lambda
    use_orth_reg = args.use_orth_reg

    # Create the output directory up front so a missing folder is reported
    # before any training time is spent.
    res_dir = os.path.dirname(res_path)
    if res_dir:
        os.makedirs(res_dir, exist_ok=True)

    # Dataset & Dataloader
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_hct)
    val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, collate_fn=collate_hct)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, collate_fn=collate_hct)

    model = HCTModel(d_model=args.hidden_dim, n_heads=args.num_heads, num_layers=args.num_layers, proj_dim=args.proj_dim,
                     chiral_encoder=args.chiral_encoder, use_qr=args.use_qr, ecd=False, num_classes=args.num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    for epoch in range(epochs):
        train_loss, train_acc, _ = run_epoch(model, train_loader, device, loss_fn,
                                        optimizer, use_orth_reg=use_orth_reg, reg_lambda=reg_lambda, mode="train")
        val_loss, val_acc, _ = run_epoch(model, val_loader, device, loss_fn,
                                    use_orth_reg=use_orth_reg, reg_lambda=reg_lambda, mode="val")
        test_loss, test_acc, test_res = run_epoch(model, test_loader, device, loss_fn,
                                        use_orth_reg=use_orth_reg, reg_lambda=reg_lambda, mode="test")

        # Model selection is done on the validation loss only; the test
        # results of that same epoch are stored alongside it.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(res_path, "wb") as f:
                pickle.dump({
                    "epoch": epoch+1,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_labels": test_res[0],
                    "test_preds": test_res[1]
                }, f)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

    return model


if __name__ == '__main__':
    args = get_args()

    def process_data(df, tag, random_ratio):
        """Convert a dataframe of RDKit molecules into per-molecule feature dicts.

        Args:
            df: pandas dataframe with columns
                'rdkit_mol_cistrans_stereo' : RDKit Mol with 3D coords
                'RS_label_binary' : 0/1
            tag: split name; selects the ChiralFinder cache file
                ``chiralfinder_{tag}.pkl``.
            random_ratio: fraction of rows whose set of chiral centres is
                randomly perturbed. For a selected row one centre is dropped
                and, with probability 1/2, replaced by a randomly chosen atom
                that was not a chiral centre.

        Returns:
            data_list: list of dicts (one per row, in row order) with keys
                'coords', 'atom_types', 'edge_types', 'atom_onehot',
                'atom_chiral', 'label', 'rdkit_mol'
        """
        import sys  # only needed for the TTY check on the progress bar below

        mol_col = 'rdkit_mol_cistrans_stereo'
        batch_size = 500
        data_list = []

        # if path exists, load directly
        # NOTE(review): the cache is assumed to hold exactly one entry per
        # row of `df`, in row order; a stale cache is not detected here.
        pth = f"/home/data/HCT/data/RS/chiralfinder_{tag}.pkl"
        if os.path.exists(pth):
            # pickle.load executes arbitrary code; only load trusted caches.
            with open(pth, "rb") as f:
                res = pickle.load(f)
        else:
            # Hydrogens are only added for ChiralFinder, so this is skipped
            # entirely when the cached result can be used.
            mols = [Chem.AddHs(m, addCoords=True)
                    for m in tqdm(df[mol_col], total=len(df), desc="processing chiral")]
            res = []
            for start in tqdm(range(0, len(mols), batch_size), desc="batch", disable=not sys.stdout.isatty()):
                cf_batch = ChiralFinder(mols[start:start + batch_size], "molecules")
                res.extend(cf_batch.get_central(n_cpus=4))
            # save
            with open(pth, "wb") as f:
                pickle.dump(res, f)

        # random_ratio, select some row positions without duplicate
        n_random = int(len(df) * random_ratio)
        random_idx = set(random.sample(range(len(df)), n_random))

        # `res` and `random_idx` are positional, whereas iterrows() yields
        # index LABELS; enumerate so a non-default dataframe index cannot
        # misalign rows with their ChiralFinder results.
        rows = tqdm(df.iterrows(), total=len(df), desc="processing data")
        for pos, (_, row) in enumerate(rows):
            mol = row[mol_col]
            label = float(row['RS_label_binary'])
            n_atoms = mol.GetNumAtoms()
            cf_centers = res[pos]["center id"]
            cf_quadrupoles = res[pos]["quadrupole matrix"]

            # 1. coordinates
            conf = mol.GetConformer()
            coords = np.array([list(conf.GetAtomPosition(i)) for i in range(n_atoms)], dtype=np.float32)

            # 2. atom_types: 0=chiral_center, 1=chiral_related, 2=non_related
            atom_types = np.ones(n_atoms, dtype=np.int64) * 2
            chiral_idx = sorted([i for i, _ in Chem.FindMolChiralCenters(mol, useLegacyImplementation=False)])

            # if random, 1/2 delete one, 1/2 replace with one idx not in the original idx
            if pos in random_idx and len(chiral_idx) > 0:
                # Snapshot before dropping so the removed centre can never
                # be picked again as its own replacement.
                original_centers = set(chiral_idx)
                replace = random.random() >= 1/2
                drop_idx = random.choice(chiral_idx)
                chiral_idx.remove(drop_idx)
                if replace:
                    non_chiral = [i for i in range(n_atoms) if i not in original_centers]
                    if len(non_chiral) > 0:
                        chiral_idx.append(random.choice(non_chiral))

            atom_types[chiral_idx] = 0

            # 3. chiral_related: neighbors of chiral centers
            adj = rdmolops.GetAdjacencyMatrix(mol)
            for idx_c in chiral_idx:
                for nb in np.where(adj[idx_c] > 0)[0]:
                    if atom_types[nb] == 2:
                        atom_types[nb] = 1

            # 4. atom_'onehot': 52, like ChIRo
            atom_onehot = getNodeFeatures(mol.GetAtoms(), mol, False)

            # 5. atom_chiral: 9-dim vector for chiral centers, zeros for others
            atom_chiral = np.zeros((n_atoms, 9), dtype=np.float32)

            # Centres reported by ChiralFinder: use its flattened matrix.
            for i, center_id in enumerate(cf_centers):
                if center_id in chiral_idx:
                    atom_chiral[center_id] = np.array(cf_quadrupoles[i][0].reshape(9), dtype=np.float32)

            # Centres ChiralFinder did not report (e.g. random replacements):
            # build a 3x3 matrix from the centre and its neighbours' positions.
            for center_id in chiral_idx:
                if center_id in cf_centers:
                    continue
                atom = mol.GetAtomWithIdx(center_id)
                # Index 0 is the centre itself, followed by its neighbours.
                neigh_cor = [coords[center_id]]
                neigh_cor.extend(coords[nbr.GetIdx()] for nbr in atom.GetNeighbors())
                if len(neigh_cor) == 4:
                    # Three neighbours: insert a virtual fourth position, the
                    # neighbours' centroid reflected through the centre.
                    neigh_cor.insert(1, (neigh_cor[1]+neigh_cor[2]+neigh_cor[3]-neigh_cor[0]*3)/3*-1.0+neigh_cor[0])
                elif len(neigh_cor) < 4:
                    # pad with 0 vectors
                    for _ in range(5-len(neigh_cor)):
                        neigh_cor.append(np.array([0.0, 0.0, 0.0], dtype=np.float32))
                # With more than four neighbours only the first four are used.
                a = neigh_cor[1] - neigh_cor[0]
                b = neigh_cor[4] - neigh_cor[3]
                c = neigh_cor[4] - neigh_cor[2]
                mat = np.array([a, b, c])
                atom_chiral[center_id] = np.array(mat.reshape(9), dtype=np.float32)

            # 6. edge_type: (n_atoms, n_atoms)
            #    1 = centre <-> chiral-related, 2 = centre <-> non-related,
            #    0 = everything else (including the diagonal).
            is_center = atom_types == 0
            center_related = is_center[:, None] & (atom_types == 1)[None, :]
            center_other = is_center[:, None] & (atom_types == 2)[None, :]
            edge_types = np.full((n_atoms, n_atoms), 0, dtype=np.int64)
            edge_types[center_related | center_related.T] = 1
            edge_types[center_other | center_other.T] = 2

            data_list.append({
                'coords': coords,
                'atom_types': atom_types,
                'edge_types': edge_types,
                'atom_onehot': atom_onehot,
                'atom_chiral': atom_chiral,
                'label': label,
                "rdkit_mol": mol
            })

        return data_list
    
    # Input splits (pickled pandas dataframes) and the output checkpoint.
    # pd.read_pickle unpickles arbitrary objects; only point these at
    # trusted files.
    params_train = "/amax/data/chirality/center_RS_classification/train_RS_classification_enantiomers_MOL_326865_55084_27542.pkl"
    params_validation = "/amax/data/chirality/center_RS_classification/validation_RS_classification_enantiomers_MOL_70099_11748_5874.pkl"
    params_test = "/amax/data/chirality/center_RS_classification/test_RS_classification_enantiomers_MOL_69719_11680_5840.pkl"
    ckpt_path = "/amax/data/HCT/res/RS/ckpt_epoch_0.pth"

    # Make sure the checkpoint directory exists BEFORE the expensive
    # pre-processing and training, so the final save cannot fail on a
    # missing folder and lose the trained weights.
    os.makedirs(os.path.dirname(ckpt_path), exist_ok=True)

    train_df = pd.read_pickle(params_train)
    val_df = pd.read_pickle(params_validation)
    test_df = pd.read_pickle(params_test)

    # NOTE(review): the same random_ratio perturbation is applied to the
    # validation and test splits as to the training split — confirm intended.
    train_data = process_data(train_df, "train", args.random_ratio)
    val_data = process_data(val_df, "val", args.random_ratio)
    test_data = process_data(test_df, "test", args.random_ratio)

    # Binary R/S classification.
    args.num_classes = 2
    model = train_hct(
        train_data, val_data, test_data, args
        )
    checkpoint = {
        "model_state_dict": model.state_dict(),
    }
    torch.save(checkpoint, ckpt_path)
