from model import *
from utils import *
from tqdm import tqdm
from torch.utils.data import DataLoader
import pickle
import pandas as pd
from rdkit import Chem
from rdkit.Chem import rdmolops
from chiralfinder import ChiralFinder
import random
from collections import defaultdict


def run_epoch(model, loader, device, optimizer=None,
              use_orth_reg=False, reg_lambda=1.0, mode="train",
              margin=0.3, save_path=None):
    """Run one pass over ``loader`` and return averaged losses and metrics.

    Each batch is a dict from which the keys ``feats_q``, ``feats_q_kv``,
    ``feats_k``, ``coords_q``, ``coords_k``, ``q_mask``, ``k_mask``,
    ``k_atom_types``, ``edge_types_qk``, ``labels_num``, ``labels_position``
    and ``labels_height`` are read and moved to ``device``.

    The model output is reshaped to 7 peak slots per sample with 20 position
    classes and 2 height classes. For a sample with ``k > 0`` peaks only the
    first ``k`` slots enter the position/height losses; a sample with zero
    peaks contributes 7 constant placeholder rows (target 0, logits with a 1
    in column 0) so that the concatenated tensors are never empty.

    Args:
        model: Callable returning
            ``(num_pred, pos_pred, height_pred, _, loss_orth_reg)``.
        loader: Iterable of batch dicts.
        device: Device the batch tensors are moved to.
        optimizer: Required when ``mode == "train"``.
        use_orth_reg: If true, add ``reg_lambda * loss_orth_reg`` to the loss.
        reg_lambda: Weight of the orthogonality regulariser.
        mode: ``"train"`` updates the model; any other value evaluates.
            ``"test"`` additionally collects per-batch labels/predictions.
        margin: Unused; kept for interface compatibility.
        save_path: Unused; kept for interface compatibility.

    Returns:
        An 8-tuple ``(loss, loss_num, loss_pos, loss_height, num_rmse,
        pos_rmse, height_acc, test_res)``. The four losses are sample-weighted
        means, the two RMSE values are means of per-batch values, and
        ``height_acc`` is the fraction of correct height rows (placeholder
        rows included). ``test_res`` is ``None`` unless ``mode == "test"``.

    Raises:
        ValueError: If training is requested without an optimizer, or a
            position/height target is outside the class range.
    """
    max_peaks, n_pos_classes, n_height_classes = 7, 20, 2

    is_train = mode == "train"
    if is_train:
        if optimizer is None:
            raise ValueError("run_epoch(mode='train') requires an optimizer")
        model.train()
        # Process-wide debugging switch; enabling it once per epoch is
        # equivalent to re-enabling it before every backward pass.
        torch.autograd.set_detect_anomaly(True)
    else:
        model.eval()

    count = 0
    loss_accum = 0
    loss_accum_3 = [0, 0, 0]  # number, position, height
    # Per key: [labels, predictions]
    test_res = {"num": [[], []], "position": [[], []], "height": [[], []]} if mode == "test" else None
    total_number = []
    total_position = []
    total_height = []

    ce_loss = torch.nn.CrossEntropyLoss()

    for batch in tqdm(loader, desc=f"{mode.capitalize()} Batches", leave=False, disable=not sys.stdout.isatty()):
        feats_q = batch['feats_q'].to(device)
        feats_q_kv = batch['feats_q_kv'].to(device)
        feats_k = batch['feats_k'].to(device)
        coords_q = batch['coords_q'].to(device)
        coords_k = batch['coords_k'].to(device)
        q_mask = batch['q_mask'].to(device)
        k_mask = batch['k_mask'].to(device)
        k_types = batch['k_atom_types'].to(device)
        edge_types_qk = batch['edge_types_qk'].to(device)
        num_gt = batch['labels_num'].to(device)
        pos_gt = batch['labels_position'].to(device)
        height_gt = batch['labels_height'].to(device)
        batch_size = num_gt.size(0)
        gt_counts = [int(v) for v in num_gt.tolist()]

        # No autograd graph is needed (or kept alive) outside training.
        with torch.set_grad_enabled(is_train):
            num_pred, pos_pred, height_pred, _, loss_orth_reg = model(
                feats_q, feats_q_kv, feats_k, k_types, edge_types_qk,
                coords_q, coords_k, q_mask, k_mask)
            pos_pred = pos_pred.view(-1, max_peaks, n_pos_classes)           # [batch, 7, 20]
            height_pred = height_pred.view(-1, max_peaks, n_height_classes)  # [batch, 7, 2]

            # Keep only the valid peak slots of every sample.
            new_gt_pos, new_pred_pos = [], []
            new_gt_height, new_pred_height = [], []
            seg_lens = []  # rows each sample occupies in the concatenated tensors
            for i, n_peaks in enumerate(gt_counts):
                if n_peaks == 0:
                    # Constant placeholder rows; they carry no gradient.
                    new_gt_pos.append(torch.zeros(max_peaks, dtype=torch.int64, device=device))
                    dummy_pos = torch.zeros((max_peaks, n_pos_classes), device=device)
                    dummy_pos[:, 0] = 1.0
                    new_pred_pos.append(dummy_pos)
                    new_gt_height.append(torch.zeros(max_peaks, dtype=torch.int64, device=device))
                    dummy_height = torch.zeros((max_peaks, n_height_classes), device=device)
                    dummy_height[:, 0] = 1.0
                    new_pred_height.append(dummy_height)
                else:
                    new_gt_pos.append(pos_gt[i, :n_peaks])
                    new_pred_pos.append(pos_pred[i, :n_peaks, :])
                    new_gt_height.append(height_gt[i, :n_peaks])
                    new_pred_height.append(height_pred[i, :n_peaks, :])
                seg_lens.append(new_gt_pos[-1].size(0))

            new_gt_pos_tensor = torch.cat(new_gt_pos, dim=0)            # [rows]
            new_pred_pos_tensor = torch.cat(new_pred_pos, dim=0)        # [rows, 20]
            new_gt_height_tensor = torch.cat(new_gt_height, dim=0)      # [rows]
            new_pred_height_tensor = torch.cat(new_pred_height, dim=0)  # [rows, 2]
            # Check targets here so a bad label fails with a readable message.
            if new_gt_pos_tensor.max().item() >= n_pos_classes:
                raise ValueError(f"position label out of range (must be < {n_pos_classes})")
            if new_gt_height_tensor.max().item() >= n_height_classes:
                raise ValueError(f"height label out of range (must be < {n_height_classes})")

            loss_pos = ce_loss(new_pred_pos_tensor, new_gt_pos_tensor)
            loss_height = ce_loss(new_pred_height_tensor, new_gt_height_tensor)
            loss_num = ce_loss(num_pred, num_gt)
            loss = loss_num + 2 * loss_height + loss_pos
            if use_orth_reg:
                loss = loss + reg_lambda * loss_orth_reg

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            for name, p in model.named_parameters():
                # Parameters that did not take part in the loss have no gradient.
                if p.grad is None:
                    continue
                if not torch.isfinite(p.grad).all():
                    print("NaN in grad:", name)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=10)
            optimizer.step()

        loss_accum += loss.item() * batch_size
        if np.isnan(loss_accum):
            print("Error: loss is nan!")
        count += batch_size
        loss_accum_3[0] += loss_num.item() * batch_size
        loss_accum_3[1] += loss_pos.item() * batch_size
        loss_accum_3[2] += loss_height.item() * batch_size

        # Peak-count RMSE of this batch.
        _, num_preds = torch.max(num_pred, dim=1)
        num_rmse = torch.sqrt(torch.mean((num_preds - num_gt).float() ** 2)).item()
        total_number.append(num_rmse)

        # Position RMSE: per sample, compare the first min(gt count, predicted
        # count) peak positions; a sample with nothing to compare scores 0.
        _, pos_preds = torch.max(new_pred_pos_tensor, dim=1)
        pred_counts = [int(v) for v in num_preds.tolist()]
        pos_rmse_list = []
        offset = 0
        for gt_num, pred_num, seg_len in zip(gt_counts, pred_counts, seg_lens):
            n = min(gt_num, pred_num)
            if n > 0:
                diff = (pos_preds[offset:offset + n] - new_gt_pos_tensor[offset:offset + n]).float()
                pos_rmse_list.append(torch.sqrt(torch.mean(diff ** 2)).item())
            else:
                pos_rmse_list.append(0.0)
            # Always step past this sample's rows (7 placeholder rows for a
            # zero-peak sample) so later samples stay aligned.
            offset += seg_len
        pos_rmse = np.mean(pos_rmse_list) if pos_rmse_list else 0.0
        total_position.append(pos_rmse)

        # Height accuracy, collected row by row.
        _, height_preds = torch.max(new_pred_height_tensor, dim=1)
        total_height.extend((height_preds == new_gt_height_tensor).tolist())

        if mode == "test":
            test_res["num"][0].append(num_gt.tolist())
            test_res["num"][1].append(num_preds.tolist())
            test_res["position"][0].append(new_gt_pos_tensor.tolist())
            test_res["position"][1].append(pos_preds.tolist())
            # Stored as per-sample tensors, detached and on the CPU so the
            # results hold neither device memory nor an autograd graph.
            test_res["height"][0].append([t.detach().cpu() for t in new_gt_height])
            test_res["height"][1].append([t.detach().cpu() for t in new_pred_height])

    final_num_rmse = np.mean(total_number) if total_number else 0.0
    final_pos_rmse = np.mean(total_position) if total_position else 0.0
    final_height_acc = np.sum(total_height) / len(total_height) if total_height else 0.0

    # An empty loader yields zeros, matching the metric fallbacks above.
    denom = count if count else 1
    return (loss_accum / denom, loss_accum_3[0] / denom, loss_accum_3[1] / denom,
            loss_accum_3[2] / denom, final_num_rmse, final_pos_rmse, final_height_acc, test_res)


def train_hct(train_data, val_data, test_data, args):
    """Train an ``HCTModel`` and pickle the metrics of the best epoch.

    Every epoch runs ``run_epoch`` on the train, validation and test loaders.
    The epoch with the highest validation height accuracy has its metrics
    (plus the test predictions) written to a pickle whose file name encodes
    the hyper-parameters in ``args``.

    Args:
        train_data, val_data, test_data: Datasets passed to ``DataLoader``
            with ``collate_hct_ecd`` as the collate function.
        args: Namespace providing ``bs``, ``num_workers``, ``hidden_dim``,
            ``num_heads``, ``num_layers``, ``proj_dim``, ``chiral_encoder``,
            ``use_qr``, ``device``, ``lr``, ``weight_decay``, ``epochs``,
            ``min_lr``, ``use_orth_reg``, ``reg_lambda`` and ``random_ratio``.

    Returns:
        The model as it stands after the final epoch (not the best epoch).
    """
    import os  # local import: this file has no explicit top-level ``import os``

    # Resolve the output location and create the directory before training,
    # so a missing directory cannot discard the results of a finished run.
    result_path = ("/home/data/HCT/res/ecd_axial/"+
                   f"{args.random_ratio}_{args.epochs}_{args.bs}_{args.lr}_{args.min_lr}_{args.weight_decay}_{args.use_qr}_{args.use_orth_reg}_"+
                   f"{args.hidden_dim}_{args.num_layers}_{args.proj_dim}_{args.chiral_encoder}.pkl")
    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    # Only the training loader is shuffled.
    loader_kwargs = dict(batch_size=args.bs, num_workers=args.num_workers, collate_fn=collate_hct_ecd)
    train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

    model = HCTModel(d_model=args.hidden_dim, n_heads=args.num_heads, num_layers=args.num_layers, proj_dim=args.proj_dim,
                     chiral_encoder=args.chiral_encoder, use_qr=args.use_qr, ecd=True).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=args.min_lr)

    # Names of the first seven values returned by run_epoch, in order.
    metric_names = ("loss", "loss_num", "loss_pos", "loss_height", "num_rmse", "pos_rmse", "height_acc")
    eval_kwargs = dict(use_orth_reg=args.use_orth_reg, reg_lambda=args.reg_lambda)

    best_acc = float('-inf')  # best validation height accuracy so far
    best_dict = {}
    for epoch in tqdm(range(args.epochs)):
        train_out = run_epoch(model, train_loader, args.device, optimizer, mode="train", **eval_kwargs)
        val_out = run_epoch(model, val_loader, args.device, mode="val", **eval_kwargs)
        test_out = run_epoch(model, test_loader, args.device, mode="test", **eval_kwargs)
        scheduler.step()

        # zip stops after the seven named metrics; the trailing test_res is
        # read separately below.
        train_m = dict(zip(metric_names, train_out))
        val_m = dict(zip(metric_names, val_out))
        test_m = dict(zip(metric_names, test_out))

        if val_m["height_acc"] > best_acc:
            best_acc = val_m["height_acc"]
            best_dict = {"epoch": epoch + 1}
            for split, metrics in (("train", train_m), ("val", val_m), ("test", test_m)):
                for key, value in metrics.items():
                    best_dict[f"{split}_{key}"] = value
            best_dict["test_res"] = test_out[-1]

        print(f"[Epoch {epoch+1}] Train Loss: {train_m['loss']:.4f}, Val Loss: {val_m['loss']:.4f}, Test Loss: {test_m['loss']:.4f}")
        print(f"Train Loss Num: {train_m['loss_num']:.4f}, Pos: {train_m['loss_pos']:.4f}, Height: {train_m['loss_height']:.4f} | ")
        print(f"Val Loss Num: {val_m['loss_num']:.4f}, Pos: {val_m['loss_pos']:.4f}, Height: {val_m['loss_height']:.4f} | ")
        print(f"Test Loss Num: {test_m['loss_num']:.4f}, Pos: {test_m['loss_pos']:.4f}, Height: {test_m['loss_height']:.4f} | ")
        print(f"Train Num RMSE: {train_m['num_rmse']:.4f}, Pos RMSE: {train_m['pos_rmse']:.4f}, Height Acc: {train_m['height_acc']:.4f} | ")
        print(f"Val Num RMSE: {val_m['num_rmse']:.4f}, Pos RMSE: {val_m['pos_rmse']:.4f}, Height Acc: {val_m['height_acc']:.4f} | ")
        print(f"Test Num RMSE: {test_m['num_rmse']:.4f}, Pos RMSE: {test_m['pos_rmse']:.4f}, Height Acc: {test_m['height_acc']:.4f} | ")

    with open(result_path, "wb") as f:
        pickle.dump(best_dict, f)

    return model


if __name__ == '__main__':
    # Script entry point. get_args is not defined in this file; presumably it
    # comes from one of the star imports at the top — TODO confirm.
    args = get_args()


    def process_data(data, random_ratio):
        """Convert raw ECD records into per-molecule feature dicts.

        For every record this builds atom coordinates, atom type codes
        (0 = chiral-axis atom, 1 = neighbour of an axis, 2 = other), an
        atom-by-atom edge-type matrix (1 between an axis atom and a marked
        neighbour, else 0), node features from ``getNodeFeatures`` and a
        9-value chirality feature per axis atom. It also prints how well the
        precomputed ChiralFinder axes agree with the annotated axes
        (coverage and IoU).

        Args:
            data: Sequence of records with keys ``id``, ``rdkit_mol``,
                ``peak_num``, ``peak_position`` and ``peak_height``.
            random_ratio: When greater than 0.05, the ChiralFinder axes are
                used instead of the annotated axes for every molecule.

        Returns:
            A list with one feature dict per record, in input order.

        Raises:
            ValueError: If the geometric fallback finds fewer than three
                neighbours around an axis.
        """

        def normalize_pairs(pairs):
            """Normalize tuples so (1,2) == (2,1)"""
            return {tuple(sorted(p)) for p in pairs}

        def compute_coverage_and_iou(label, pred):
            """Return (1 if every label axis was predicted else 0, IoU of the axis sets)."""
            label_set = normalize_pairs(label)
            pred_set = normalize_pairs(pred)
            inter = label_set & pred_set
            union = label_set | pred_set
            coverage = 1 if inter == label_set else 0
            iou = len(inter) / len(union) if len(union) > 0 else 1.0
            return coverage, iou

        def mark_neighbors(atom_types, edge_types, neighbor_ids, axis_atoms, n_atoms):
            """Tag still-untyped neighbours as type 1 and link them to the axis atoms."""
            for nb in neighbor_ids:
                # Indices >= n_atoms are skipped ("omit H atom" in the original).
                if nb < n_atoms and atom_types[nb] == 2:
                    atom_types[nb] = 1
                    for axis_atom in axis_atoms:
                        edge_types[nb, axis_atom] = 1
                        edge_types[axis_atom, nb] = 1

        data_list = []

        # Precomputed result of ChiralFinder(<rdkit mols>, "molecules").get_axial().
        # NOTE(review): this path is under /home/data while the other inputs
        # are under /amax/data — confirm this is intended.
        # NOTE(security): pickle.load executes arbitrary code; trusted files only.
        with open("/home/data/HCT/data/hct_ecd_axial_res.pkl", "rb") as f:
            res = pickle.load(f)
        labels = pd.read_excel("/amax/data/chirality/axial/axial_650.xlsx")
        coverage_scores = []
        IoU_scores = []

        # NOTE(review): random_idx is never consulted below; the label swap is
        # controlled only by `random_ratio > 0.05`. The sampling call is kept
        # because it advances the global `random` state.
        n_random = int(len(data) * random_ratio)
        random_idx = set(random.sample(range(len(data)), n_random))

        for idx, row in enumerate(tqdm(data, total=len(data), desc="processing data", disable=not sys.stdout.isatty())):
            finder = res[idx]
            pred_axes = finder["chiral axes"]

            # The last "_"-separated field of the id (absolute value) selects
            # the row of the label spreadsheet.
            one_label_id = abs(int(row["id"].split("_")[-1]))
            # NOTE(security): eval on spreadsheet text; ast.literal_eval would
            # be safer if the cells are plain literals — confirm before changing.
            one_label = eval(labels["label"][one_label_id])
            one_pred = pred_axes
            coverage, iou = compute_coverage_and_iou(one_label, one_pred)
            coverage_scores.append(coverage)
            IoU_scores.append(iou)
            chiral_type = labels["chiral_type"][one_label_id]

            mol = row['rdkit_mol']
            label_num = int(row['peak_num'])
            label_position = np.array(row['peak_position'], dtype=np.int8)
            label_height = np.array(row['peak_height'], dtype=np.int8)
            n_atoms = mol.GetNumAtoms()

            # 1. coordinates
            conf = mol.GetConformer()
            coords = np.array([list(conf.GetAtomPosition(i)) for i in range(n_atoms)], dtype=np.float32)

            # 2. atom_types: 0=chiral_center, 1=chiral_related, 2=non_related
            atom_types = np.ones(n_atoms, dtype=np.int64) * 2
            if random_ratio > 0.05:
                one_label = pred_axes
            chiral_idx = [x for t in one_label for x in t]
            if not chiral_idx:
                print(f"Warning: no chiral center found in molecule {idx}, SMILES: {Chem.MolToSmiles(mol)}")
            atom_types[chiral_idx] = 0

            # 3. edge types between axis atoms and their neighbours (filled below)
            edge_types = np.full((n_atoms, n_atoms), 0, dtype=np.int64)

            # 4. node features from getNodeFeatures ("52, like ChIRo" per the original note)
            atoms = mol.GetAtoms()
            atom_onehot = getNodeFeatures(atoms, mol, False)

            # 5. atom_chiral: 9-dim vector for chiral-axis atoms, zeros for others
            atom_chiral = np.zeros((n_atoms, 9), dtype=np.float32)
            tag = False  # becomes True once any label axis matches a ChiralFinder axis
            for one_axial in tqdm(one_label, desc=f"{idx}", leave=False, disable=not sys.stdout.isatty()):
                for i, pred_axis in enumerate(pred_axes):
                    if set(one_axial) != set(pred_axis):
                        continue
                    tag = True
                    neighbors = finder["neighbor ids"][i]
                    # Neighbour ids may be nested per axis atom; flatten them.
                    if len(neighbors) > 0 and isinstance(neighbors[0], list):
                        neighbors = [nb for group in neighbors for nb in group]
                    mark_neighbors(atom_types, edge_types, neighbors, one_axial, n_atoms)
                    # Chirality matrix: one matrix per axis atom when the entry
                    # is nested, otherwise one shared matrix for the axis.
                    quad = finder["quadrupole matrix"][i]
                    if isinstance(quad[0], list):
                        for k, j in enumerate(pred_axis):
                            atom_chiral[j] = np.array(quad[k][0].reshape(9), dtype=np.float32)
                    else:
                        for j in pred_axis:
                            atom_chiral[j] = np.array(quad[0].reshape(9), dtype=np.float32)
                    break

            if not tag:
                print(f"Warning: no chiral axis matched in molecule {idx}, SMILES: {Chem.MolToSmiles(mol)}")
                # Geometric fallback. For every label axis, e.g. (12, 14):
                # collect the neighbours of its atoms (excluding the axis atoms
                # themselves), keep the four lowest indices, and prepend the
                # centre of the axis atoms to get five points [0..4]. The
                # feature is built from a = p1 - p0, b = p2 - p0, c = p3 - p4,
                # and the neighbours are marked like matched-axis neighbours.
                #
                # The hydrogen-added molecule is kept in separate variables so
                # the stored `coords` and `rdkit_mol` stay consistent with the
                # n_atoms-sized feature arrays built above.
                mol_h = rdmolops.AddHs(mol, addCoords=True)
                conf_h = mol_h.GetConformer()
                coords_h = np.array([list(conf_h.GetAtomPosition(i)) for i in range(mol_h.GetNumAtoms())], dtype=np.float32)
                for one_axial in one_label:
                    neigh_set = {
                        nbr.GetIdx()
                        for axis_atom in one_axial
                        for nbr in mol_h.GetAtomWithIdx(int(axis_atom)).GetNeighbors()
                        if nbr.GetIdx() not in one_axial
                    }
                    if len(neigh_set) < 3:
                        raise ValueError(
                            f"molecule {idx}: axis {one_axial} has only {len(neigh_set)} neighbours; need at least 3")
                    neigh_list = sorted(neigh_set)[:4]
                    center = np.mean([coords_h[x] for x in one_axial], axis=0)
                    neigh_cor = [center] + [coords_h[x] for x in neigh_list]
                    # With only three neighbours, pad with the negated mean of
                    # their coordinates.
                    if len(neigh_list) == 3:
                        neigh_cor.append(-np.mean([coords_h[x] for x in neigh_list], axis=0))
                    a = neigh_cor[1] - neigh_cor[0]
                    b = neigh_cor[2] - neigh_cor[0]
                    c = neigh_cor[3] - neigh_cor[4]
                    mat = np.array([a, b, c])
                    for j in one_axial:
                        atom_chiral[j] = mat.reshape(9)
                    mark_neighbors(atom_types, edge_types, neigh_list, one_axial, n_atoms)

            data_list.append({
                'coords': coords,
                'atom_types': atom_types,
                'edge_types': edge_types,
                'atom_onehot': atom_onehot,
                'atom_chiral': atom_chiral,
                'label_num': label_num,
                'label_position': label_position,
                'label_height': label_height,
                "rdkit_mol": mol,
                'chiral_type': chiral_type
            })

        print(f"Average Coverage: {np.mean(coverage_scores)}, average IoU: {np.mean(IoU_scores)}")
        return data_list
    
    # Load the raw records and turn them into model-ready feature dicts.
    params_p = "/amax/data/HCT/data/hct_ecd_axial.pkl"
    data = process_data(pd.read_pickle(params_p), random_ratio=args.random_ratio)

    # Fixed train/val/test partition, stored as lists of indices into `data`.
    with open("/amax/data/HCT/data/ecd_axial_index_split.pkl", "rb") as f:
        index_ = pickle.load(f)
    train_index = index_["train_index"]
    val_index = index_["val_index"]
    test_index = index_["test_index"]

    train_data, val_data, test_data = (
        [data[i] for i in split] for split in (train_index, val_index, test_index)
    )

    def temp_remove_others(data):
        """Return new records holding only the three label fields and the molecule."""
        kept_keys = ("label_num", "label_position", "label_height", "rdkit_mol")
        return [{key: record[key] for key in kept_keys} for record in data]


    # Train, then persist the weights of the model as returned by train_hct.
    model = train_hct(train_data, val_data, test_data, args)

    checkpoint = {"model_state_dict": model.state_dict()}
    torch.save(checkpoint, "/amax/data/HCT/res/ecd_axial/ckpt_epoch_last.pth")