from model import *
from utils import *
from tqdm import tqdm
from torch.utils.data import DataLoader
import pickle
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdmolops
from chiralfinder import ChiralFinder, ChiralCenter
import random
from collections import defaultdict
import os


def run_epoch(model, loader, device, loss_fn, optimizer=None,
              use_orth_reg=False, reg_lambda=1.0, mode="train", save_path=None):
    """Run one full pass over ``loader`` in train, validation or test mode.

    Args:
        model: Module called as ``model(feats_q, feats_q_kv, feats_k, k_types,
            edge_types_qk, coords_q, coords_k, q_mask, k_mask)``. It must
            return a 3-tuple: per-class scores (argmax over dim 1 is the
            prediction), an ignored middle value, and a regularisation loss.
        loader: Iterable of batch dicts providing the tensors read below
            (``feats_q``, ``feats_q_kv``, ``feats_k``, ``coords_q``,
            ``coords_k``, ``q_mask``, ``k_mask``, ``k_atom_types``,
            ``edge_types_qk``, ``labels``).
        device: Device every batch tensor is moved to.
        loss_fn: Callable ``loss_fn(output, labels)`` returning a scalar loss.
        optimizer: Optimizer stepped once per batch. Required when
            ``mode == "train"``; not used otherwise.
        use_orth_reg: If true, ``reg_lambda * loss_orth_reg`` is added to the
            loss (in every mode, so reported eval losses include it too).
        reg_lambda: Weight of the regularisation term.
        mode: ``"train"`` puts the model in training mode and updates the
            weights; any other value puts it in eval mode. ``"test"``
            additionally collects labels and predictions.
        save_path: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        ``(avg_loss, avg_acc, test_res)``. The averages are weighted by batch
        size. ``test_res`` is ``[labels, predictions]`` (two flat Python lists)
        when ``mode == "test"`` and ``None`` otherwise.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    is_train = mode == "train"
    if is_train:
        model.train()
        # Debugging aid: makes backward() raise on NaN-producing operations.
        # It is a process-wide switch, so enabling it once per call is enough;
        # note that it slows training down noticeably.
        torch.autograd.set_detect_anomaly(True)
    else:
        model.eval()
    total_loss, total_acc, count = 0.0, 0.0, 0
    test_res = [[], []] if mode == "test" else None

    for batch in tqdm(loader, desc=f"{mode.capitalize()} Batches", leave=False):
        feats_q = batch['feats_q'].to(device)
        feats_q_kv = batch['feats_q_kv'].to(device)
        feats_k = batch['feats_k'].to(device)
        coords_q = batch['coords_q'].to(device)
        coords_k = batch['coords_k'].to(device)
        q_mask = batch['q_mask'].to(device)
        k_mask = batch['k_mask'].to(device)
        k_types = batch['k_atom_types'].to(device)
        edge_types_qk = batch['edge_types_qk'].to(device)
        labels = batch['labels'].to(device).long()

        if is_train:
            optimizer.zero_grad()

        # Only record the autograd graph when it will actually be used;
        # evaluation passes would otherwise hold activations for no reason.
        with torch.set_grad_enabled(is_train):
            output, _, loss_orth_reg = model(feats_q, feats_q_kv, feats_k, k_types, edge_types_qk,
                                             coords_q, coords_k, q_mask, k_mask)
            loss = loss_fn(output, labels)
            if use_orth_reg:
                loss = loss + reg_lambda * loss_orth_reg

        if is_train:
            loss.backward()
            # Best-effort diagnostics: report parameters whose gradient is
            # missing, unreadable, or contains NaN/Inf. Never aborts training.
            for name, p in model.named_parameters():
                if p.grad is None:
                    # No gradient reached this parameter in this step.
                    print("Error checking grad:", name)
                    continue
                try:
                    bad_grad = bool(torch.isnan(p.grad).any() or torch.isinf(p.grad).any())
                except (RuntimeError, TypeError):
                    print("Error checking grad:", name)
                    continue
                if bad_grad:
                    print("NaN in grad:", name)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()

        batch_size = len(labels)
        total_loss += loss.item() * batch_size
        total_acc += compute_accuracy_multiclasses(output, labels) * batch_size
        count += batch_size

        if mode == "test":
            preds = torch.argmax(output, dim=1)
            test_res[0].extend(labels.cpu().numpy().tolist())
            test_res[1].extend(preds.detach().cpu().numpy().tolist())

    avg_loss = total_loss / count
    avg_acc = total_acc / count

    return avg_loss, avg_acc, test_res


def train_hct(train_data, val_data, test_data, args):
    """Train an ``HCTModel`` and record test results at the best validation loss.

    Every epoch runs a training pass, a validation pass and a test pass.
    Whenever the validation loss improves on the best value seen so far, the
    metrics of that epoch together with the test labels/predictions are
    pickled, overwriting the previous file.

    Args:
        train_data: Dataset for the training ``DataLoader`` (shuffled).
        val_data: Dataset for the validation ``DataLoader``.
        test_data: Dataset for the test ``DataLoader``.
        args: Namespace providing ``bs``, ``hidden_dim``, ``num_heads``,
            ``num_layers``, ``proj_dim``, ``chiral_encoder``, ``use_qr``,
            ``num_classes``, ``device``, ``lr``, ``weight_decay``, ``epochs``,
            ``use_orth_reg``, ``reg_lambda``, ``min_lr`` and ``dataset``.

    Returns:
        The model as it stands after the final epoch. No best-epoch weights
        are restored; only the result pickle reflects the best epoch.
    """
    # Dataset & Dataloader
    train_loader = DataLoader(train_data, batch_size=args.bs, shuffle=True, collate_fn=collate_hct)
    val_loader = DataLoader(val_data, batch_size=args.bs, shuffle=False, collate_fn=collate_hct)
    test_loader = DataLoader(test_data, batch_size=args.bs, shuffle=False, collate_fn=collate_hct)

    model = HCTModel(d_model=args.hidden_dim, n_heads=args.num_heads, num_layers=args.num_layers, proj_dim=args.proj_dim,
                     chiral_encoder=args.chiral_encoder, use_qr=args.use_qr, ecd=False, num_classes=args.num_classes).to(args.device)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    loss_fn = torch.nn.CrossEntropyLoss()

    # The result file name depends only on the run configuration, so build it
    # once. Creating the directory up front avoids losing a whole epoch of
    # work to a FileNotFoundError at the first save.
    prefix = f"epochs{args.epochs}_{args.bs}_{args.lr}_{args.min_lr}_{args.weight_decay}_qr_{args.hidden_dim}_{args.num_layers}_{args.proj_dim}_{args.chiral_encoder}"
    res_path = f"/home/data/HCT/res/optical_axial/test_res/hct_{args.dataset}_{prefix}_test_res.pkl"
    os.makedirs(os.path.dirname(res_path), exist_ok=True)

    best_val_loss = float('inf')
    for epoch in range(args.epochs):
        train_loss, train_acc, _ = run_epoch(model, train_loader, args.device, loss_fn,
                                        optimizer, use_orth_reg=args.use_orth_reg, reg_lambda=args.reg_lambda, mode="train")
        val_loss, val_acc, _ = run_epoch(model, val_loader, args.device, loss_fn,
                                    use_orth_reg=args.use_orth_reg, reg_lambda=args.reg_lambda, mode="val")
        test_loss, test_acc, test_res = run_epoch(model, test_loader, args.device, loss_fn,
                                        use_orth_reg=args.use_orth_reg, reg_lambda=args.reg_lambda, mode="test")

        # Model selection is driven by validation loss only; the test numbers
        # are merely recorded for the selected epoch.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            with open(res_path, "wb") as f:
                pickle.dump({
                    "epoch": epoch+1,
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "test_loss": test_loss,
                    "test_acc": test_acc,
                    "test_labels": test_res[0],
                    "test_preds": test_res[1]
                }, f)

        print(f"[Epoch {epoch+1}] Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Acc: {test_acc:.4f}")

    return model


if __name__ == '__main__':
    args = get_args()

    def process_data(data, random_ratio):
        """Turn raw molecule records into per-molecule feature dictionaries.

        For each record this builds atom coordinates, an atom-role vector, a
        chiral-edge matrix, node features and a 9-value chiral descriptor per
        atom, and derives a binary label from the sign of the ``OR_589nm``
        value looked up by molecule id. It also prints how well the
        precomputed chiral axes agree with the annotated ones.

        Args:
            data: Sized iterable of records. Each record is read through
                ``row["id"]`` (a string whose last ``_``-separated field is an
                integer, optionally with a leading ``-``) and
                ``row["rdkit_mol"]`` (an RDKit molecule with a conformer).
                Record ``idx`` must correspond to entry ``idx`` of the
                precomputed result pickle.
            random_ratio: When greater than 0.05, the precomputed chiral axes
                replace the annotated ones as the source of chiral atoms. It
                also sizes a random index sample that is currently unused.

        Returns:
            A list with one dict per record, with keys ``coords``,
            ``atom_types``, ``edge_types``, ``atom_onehot``, ``atom_chiral``,
            ``label``, ``rdkit_mol`` and ``chiral_type``. All per-atom arrays
            are indexed by the atoms of the record's own molecule.
        """
        data_list = []

        # Precomputed ChiralFinder output; per the original notes it was
        # produced with ChiralFinder([...rdkit_mol...], "molecules").get_axial()
        # and pickled so it does not have to be recomputed on every run.
        with open("/home/data/HCT/data/hct_ecd_axial_res.pkl", "rb") as f:
            res = pickle.load(f)
        labels = pd.read_excel("/amax/data/chirality/axial/axial_650.xlsx")
        optical_labels = pd.read_csv("/home/data/HCT/data/optical_axial/optical_rotation_589nm.csv")
        id2optical = {
            int(mol_id): float(rotation)
            for mol_id, rotation in zip(optical_labels["id"], optical_labels["OR_589nm"])
        }
        coverage_scores = []
        IoU_scores = []

        def normalize_pairs(pairs):
            """Normalize tuples so (1,2) == (2,1)"""
            return {tuple(sorted(p)) for p in pairs}

        def compute_coverage_and_iou(label, pred):
            """Return (coverage, IoU) between annotated and predicted axes.

            Coverage is 1 only if every annotated axis was predicted; IoU is
            defined as 1.0 when both sets are empty.
            """
            label_set = normalize_pairs(label)
            pred_set = normalize_pairs(pred)
            inter = label_set & pred_set
            union = label_set | pred_set
            coverage = 1 if inter == label_set else 0
            iou = len(inter) / len(union) if len(union) > 0 else 1.0
            return coverage, iou

        def mark_related(atom_types, edge_types, neighbor_ids, axis_atoms, n_atoms):
            """Flag neighbours of a chiral axis and connect them to it, in place.

            Indices >= n_atoms are skipped (original note: "omit H atom").
            Only atoms still typed 2 are promoted to type 1, so chiral atoms
            and already-marked neighbours are left alone.
            """
            for nb in neighbor_ids:
                if nb < n_atoms and atom_types[nb] == 2:
                    atom_types[nb] = 1
                    for one_chiral_id in axis_atoms:
                        edge_types[nb, one_chiral_id] = 1
                        edge_types[one_chiral_id, nb] = 1

        # NOTE(review): the sampled index set is not consumed anywhere below.
        # The call is kept so the global `random` state advances as before.
        n_random = int(len(data) * random_ratio)
        random_idx = set(random.sample(range(len(data)), n_random))

        # Progress bars are only drawn on an interactive terminal.
        quiet = not sys.stdout.isatty()
        for idx, row in enumerate(tqdm(data, total=len(data), desc="processing data", disable=quiet)):
            id_wo_abs_str = row["id"].split("_")[-1]
            one_label_id = abs(int(id_wo_abs_str))
            # NOTE(security): eval() executes whatever the spreadsheet cell
            # contains. Only use with a trusted file; ast.literal_eval would be
            # the safe choice if the cells are always plain literals.
            one_label = eval(labels["label"][one_label_id])
            one_pred = res[idx]["chiral axes"]
            # compute coverage and IoU scores
            coverage, iou = compute_coverage_and_iou(one_label, one_pred)
            coverage_scores.append(coverage)
            IoU_scores.append(iou)
            chiral_type = labels["chiral_type"][one_label_id]

            mol = row['rdkit_mol']
            # A "-" in the id field inverts the sign-based label for that record.
            positive = 1 if id2optical[one_label_id] > 0 else 0
            label = 1 - positive if "-" in id_wo_abs_str else positive
            n_atoms = mol.GetNumAtoms()

            # 1. coordinates
            conf = mol.GetConformer()
            coords = np.array([list(conf.GetAtomPosition(i)) for i in range(n_atoms)], dtype=np.float32)

            # 2. atom_types: 0=chiral_center, 1=chiral_related, 2=non_related
            atom_types = np.ones(n_atoms, dtype=np.int64) * 2
            if random_ratio > 0.05:
                # Use the precomputed axes instead of the annotated ones.
                one_label = one_pred
            chiral_idx = [x for t in one_label for x in t]
            if not chiral_idx:
                print(f"Warning: no chiral center found in molecule {idx}, SMILES: {Chem.MolToSmiles(mol)}")

            atom_types[chiral_idx] = 0

            # 3. chiral_related: neighbors of chiral centers (filled in below)
            edge_types = np.full((n_atoms, n_atoms), 0, dtype=np.int64)

            # 4. atom_'onehot': 52, like ChIRo
            # Called on the instance so no bare `rdkit` name has to be in scope.
            atoms = mol.GetAtoms()
            atom_onehot = getNodeFeatures(atoms, mol, False)

            # 5. atom_chiral: 9-dim vector for chiral centers, zeros for others
            atom_chiral = np.zeros((n_atoms, 9), dtype=np.float32)
            tag = False  # becomes True once any axis matches a precomputed one
            for one_axial in tqdm(one_label, desc=f"{idx}", leave=False, disable=quiet):
                for i, pred_axis in enumerate(one_pred):
                    if set(one_axial) != set(pred_axis):
                        continue
                    tag = True
                    # neighbor ids may be nested one level; flatten if so
                    neighbors = res[idx]["neighbor ids"][i]
                    if isinstance(neighbors[0], list):
                        neighbors = sum(neighbors, [])
                    mark_related(atom_types, edge_types, neighbors, one_axial, n_atoms)
                    # chiral matrix: either one matrix per axis atom or one
                    # matrix shared by all atoms of the axis
                    quad = res[idx]["quadrupole matrix"][i]
                    if isinstance(quad[0], list):
                        for k, j in enumerate(pred_axis):
                            atom_chiral[j] = np.array(quad[k][0].reshape(9), dtype=np.float32)
                    else:
                        for j in pred_axis:
                            atom_chiral[j] = np.array(quad[0].reshape(9), dtype=np.float32)
                    break
            if not tag:
                print(f"Warning: no chiral axis matched in molecule {idx}, SMILES: {Chem.MolToSmiles(mol)}")
                # Fallback descriptor. For each annotated axis, e.g. atoms 12
                # and 14 of [(12, 14)] (one or more atoms), collect the
                # neighbours of those atoms (as a set, excluding the axis atoms
                # themselves), keep the first four by index, and prepend the
                # centre of the axis atoms, giving five points [0, 1, 2, 3, 4].
                # Then a = p[1] - p[0], b = p[2] - p[0], c = p[3] - p[4] form
                # the atom_chiral feature of the axis atoms, and the selected
                # neighbours are marked exactly like in the matched case.
                #
                # The hydrogen-added molecule is kept in separate variables so
                # the stored `coords` / `rdkit_mol` stay aligned with the other
                # per-atom arrays, which are sized by the original `n_atoms`.
                mol_h = rdmolops.AddHs(mol, addCoords=True)
                conf_h = mol_h.GetConformer()
                coords_h = np.array([list(conf_h.GetAtomPosition(i)) for i in range(mol_h.GetNumAtoms())], dtype=np.float32)
                for one_axial in one_label:
                    neigh_set = set()
                    for one_chiral_id in one_axial:
                        for nb_atom in mol_h.GetAtomWithIdx(one_chiral_id).GetNeighbors():
                            if nb_atom.GetIdx() not in one_axial:
                                neigh_set.add(nb_atom.GetIdx())
                    assert len(neigh_set) >= 3
                    neigh_list = sorted(neigh_set)[:4]
                    # center
                    center = np.mean([coords_h[x] for x in one_axial], axis=0)
                    neigh_cor = [center] + [coords_h[x] for x in neigh_list]
                    # With only three neighbours, pad with the negated mean of
                    # those three neighbours as the missing fifth point.
                    if len(neigh_list) == 3:
                        neigh_cor.append(-np.mean([coords_h[x] for x in neigh_list], axis=0))
                    a = neigh_cor[1] - neigh_cor[0]
                    b = neigh_cor[2] - neigh_cor[0]
                    c = neigh_cor[3] - neigh_cor[4]
                    mat = np.array([a, b, c])
                    for j in one_axial:
                        atom_chiral[j] = mat.reshape(9)
                    mark_related(atom_types, edge_types, neigh_list, one_axial, n_atoms)

            data_list.append({
                'coords': coords,
                'atom_types': atom_types,
                'edge_types': edge_types,
                'atom_onehot': atom_onehot,
                'atom_chiral': atom_chiral,
                'label': label,
                "rdkit_mol": mol,
                'chiral_type': chiral_type
            })

        print(f"Average Coverage: {np.mean(coverage_scores)}, average IoU: {np.mean(IoU_scores)}")
        return data_list
    
    # Load the pickled molecule records and convert them to feature dicts.
    params_p = "/amax/data/HCT/data/hct_ecd_axial.pkl"
    data = process_data(pd.read_pickle(params_p), random_ratio=args.random_ratio)

    # The train/val/test partition is stored as index collections into `data`.
    with open("/amax/data/HCT/data/ecd_axial_index_split.pkl", "rb") as f:
        index_ = pickle.load(f)
    train_index = index_["train_index"]
    val_index = index_["val_index"]
    test_index = index_["test_index"]

    train_data, val_data, test_data = (
        [data[i] for i in split_index]
        for split_index in (train_index, val_index, test_index)
    )

    # To cache the materialised splits, pickle
    # {"train": train_data, "val": val_data, "test": test_data}
    # to /home/data/HCT/data/optical_axial/split.pkl (currently disabled).

    args.num_classes = 2  # optical is binary classification
    model = train_hct(train_data, val_data, test_data, args)

    # Only the weights are stored; the file name is fixed regardless of how
    # many epochs were run.
    checkpoint = {"model_state_dict": model.state_dict()}
    torch.save(checkpoint, "/amax/data/HCT/res/optical_axial/ckpt_epoch_0.pth")
