import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset, rm_duplicates
import pandas as pd
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from torch.utils.data import DataLoader
from tqdm import tqdm
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=False)

from graphs import atoms2graphs, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet

# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)
import random
import os

import torch.multiprocessing
# Share tensors between worker processes (e.g. the DataLoader's num_workers)
# through the file system instead of file descriptors.
# NOTE(review): presumably chosen to avoid "too many open files" errors — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

# Set the random seed for Python, NumPy, and PyTorch
# (main() later reseeds torch with --split_seed).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Run on CUDA when it is available, otherwise on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
    # Prefer reproducible cuDNN kernels over autotuned (benchmark) selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Shared converter; structure_to_graphs calls adptor.get_atoms(...) on each row.
adptor = JarvisAtomsAdaptor()

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 16
):
    """Build one graph per row of ``df`` from its ``p_input`` column.

    Each ``p_input`` entry is converted with the module-level adaptor
    (``adptor.get_atoms(entry["structure"])``) and passed to ``atoms2graphs``.
    Rows are processed in parallel with pandarallel's ``parallel_apply``.

    Args:
        df: frame with a ``p_input`` column; each entry must provide the keys
            ``"structure"`` and ``"equivalent_atoms"``.
        use_corrected_structure: accepted for interface compatibility only;
            this function does not read it.
        reduce_cell: forwarded to ``atoms2graphs`` as ``reduce``.
        cutoff: forwarded to ``atoms2graphs`` as ``cutoff``.
        max_neighbors: forwarded to ``atoms2graphs`` as ``max_neighbors``.

    Returns:
        ``.values`` of the Series of ``atoms2graphs`` results, one per row.
    """
    # Options shared by every row; only the structure and equivalent atoms vary.
    graph_options = dict(
        cutoff=cutoff,
        max_neighbors=max_neighbors,
        reduce=reduce_cell,
        use_canonize=True,
    )

    def _row_to_graph(row_input):
        atoms = adptor.get_atoms(row_input["structure"])
        return atoms2graphs(
            atoms,
            equivalent_atoms=row_input['equivalent_atoms'],
            **graph_options,
        )

    graph_series = df["p_input"].parallel_apply(_row_to_graph)
    return graph_series.values

def get_pyg_dataset(data, target, reduce_cell=False):
    """Wrap the records in ``data`` in a ``GraphDataset`` for ``target``.

    ``data`` is turned into a DataFrame, graphs are built row-wise with
    ``structure_to_graphs`` (forwarding ``reduce_cell``), and frame plus
    graphs are handed to ``GraphDataset``.
    """
    frame = pd.DataFrame(data)
    return GraphDataset(
        df=frame,
        graphs=structure_to_graphs(frame, reduce_cell=reduce_cell),
        target=target,
    )


def get_cartesian_rotations_from_sym_dataset(structure, sym_dataset):
    """Express the symmetry rotations of ``sym_dataset`` in Cartesian coordinates.

    Args:
        structure: object exposing ``lattice.matrix`` (3x3). Its transpose
            ``A`` is used for the similarity transform ``A @ R @ inv(A)``.
        sym_dataset: mapping with an optional ``"rotations"`` entry, or ``None``.
            NOTE(review): assumes the rotations are in lattice (fractional)
            coordinates, as spglib reports them — confirm against the data module.

    Returns:
        Array of Cartesian rotation matrices. When ``sym_dataset`` is ``None``
        or has no rotations, the result is an empty ``(0, 3, 3)`` float64 array.
        Otherwise the rotations are first passed through the project helper
        ``rm_duplicates``.
    """
    no_ops = np.empty((0, 3, 3), dtype=np.float64)
    if sym_dataset is None:
        return no_ops

    frac_rots = np.array(sym_dataset.get("rotations", []))
    if frac_rots.size == 0:
        return no_ops

    frac_rots = rm_duplicates(frac_rots)
    basis = structure.lattice.matrix.T
    basis_inv = np.linalg.inv(basis)
    # A @ (R @ A^-1); the association order is kept explicit for identical floats.
    return np.matmul(basis, np.matmul(frac_rots, basis_inv))


def infer_forced_zero_mask_symmetric_rank2(
    cart_rots: np.ndarray,
    rcond: float = 1e-10,
    zero_tol: float = 1e-8,
):
    """Find entries of a symmetric 3x3 tensor that the given rotations force to zero.

    Each of the six symmetric basis matrices is averaged over the operations,
    ``(1/n) * sum_R R @ B @ R.T``, and the nine flattened averages form the
    columns of a 9x6 matrix. An orthonormal basis of that matrix's column
    space is taken from the left singular vectors whose singular value exceeds
    ``rcond * max_singular_value``. An entry is reported as forced-zero when
    every one of those basis vectors has magnitude below ``zero_tol`` there.

    Args:
        cart_rots: sequence of 3x3 Cartesian rotation matrices (may be empty or None).
        rcond: relative singular-value cutoff for the kept subspace.
        zero_tol: absolute threshold below which an entry counts as zero.

    Returns:
        Boolean (3, 3) array, True where the entry is forced to zero. It is all
        False when there are no operations or the averaged matrix is degenerate.
    """
    no_mask = np.zeros((3, 3), dtype=bool)
    if cart_rots is None or len(cart_rots) == 0:
        return no_mask

    # Symmetric basis order: xx, yy, zz, yz, xz, xy (off-diagonals set on both sides).
    index_pairs = [(0, 0), (1, 1), (2, 2), (1, 2), (0, 2), (0, 1)]
    n_ops = float(len(cart_rots))

    columns = []
    for i, j in index_pairs:
        unit = np.zeros((3, 3), dtype=np.float64)
        unit[i, j] = 1.0
        unit[j, i] = 1.0
        acc = np.zeros((3, 3), dtype=np.float64)
        for rot in cart_rots:
            acc += rot @ unit @ rot.T
        acc /= n_ops
        columns.append(acc.reshape(-1))  # row-major (9,)

    projected = np.stack(columns, axis=1)  # (9, 6)
    left_vecs, sing_vals, _ = np.linalg.svd(projected, full_matrices=False)
    if sing_vals.size == 0:
        return no_mask

    significant = sing_vals > (rcond * sing_vals[0])
    if not significant.any():
        # Degenerate: refuse to mask every entry.
        return no_mask

    invariant_basis = left_vecs[:, significant]  # (9, k)
    forced = np.abs(invariant_basis).max(axis=1) < zero_tol
    return forced.reshape(3, 3)


def test(model, args):
    """Evaluate ``model`` on the test split of ``args.target`` and print metrics.

    The dataset is loaded with ``get_dataset`` and split with
    ``get_id_train_val_test``. Only the test split is converted to graphs and
    evaluated, with batch size 1 and no shuffling. Printed output:

    * atomic numbers (0..119) never seen in the training split;
    * MAE, mean Frobenius error, and EwT 25/10/5/2. EwT is the fraction of
      samples whose relative Frobenius error is below the threshold; samples
      with a zero label norm are excluded;
    * per crystal-system counts of zero-entry and equality symmetry violations
      for labels and predictions, plus mean Frobenius error.

    With ``args.zero_mask`` set, a prediction entry is zeroed when the sample's
    symmetry operations force it to zero or when ``abs(ideal_matrix) < 1.0``.
    The mean number of zeroed entries per sample is also printed.

    Args:
        model: module called as ``model(structure, mask, equality)``. Its
            output is reshaped to 3x3.
        args: parsed command-line namespace (see ``main``).
    """
    if args.load_preprocessed:
        print("load preprocessed dataset ...")

    dataset_sym = get_dataset(
        dataset_name=args.target,
        use_corrected_structure=args.use_corrected_structure,
        load_preprocessed=args.load_preprocessed,
    )

    id_train, id_val, id_test = get_id_train_val_test(
        total_size=len(dataset_sym),
        split_seed=args.split_seed,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        keep_data_order=False,
    )

    dataset_train = [dataset_sym[x] for x in id_train]
    dataset_test = [dataset_sym[x] for x in id_test]

    # Atomic numbers in [0, 120) that occur in no training structure.
    seen_ele = np.zeros(120, dtype=bool)
    for itm in dataset_train:
        elems = np.asarray(itm["structure"].atomic_numbers, dtype=np.int64)
        elems = elems[(elems >= 0) & (elems < 120)]
        seen_ele[elems] = True

    unseen_list = np.where(~seen_ele)[0].tolist()
    print("unseen elements:", unseen_list)

    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target)
    # batch_size=1 and shuffle=False: the loop index is used to look up per-sample
    # metadata in ``dataset_test``.
    # NOTE(review): assumes GraphDataset preserves the row order of dataset_test.
    test_loader = DataLoader(
        pyg_dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=pyg_dataset_test.collate,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    print("n_test:", len(test_loader.dataset))

    def _to_tensor(x):
        """Return ``x`` if it is already a tensor, else ``torch.as_tensor(x)``."""
        # ideal_matrix can be a NumPy array or a torch.Tensor
        if torch.is_tensor(x):
            return x
        return torch.as_tensor(x)

    def _bucket_name(space_g: int) -> str:
        """Map a space-group number to a crystal-system bucket key.

        Numbers 143-194 (trigonal and hexagonal) share the "HEXA" bucket.
        """
        if space_g >= 195:
            return "CUBIC"
        if space_g >= 143:
            return "HEXA"
        if space_g >= 75:
            return "TETR"
        if space_g >= 16:
            return "ORTH"
        if space_g >= 3:
            return "MONO"
        return "TRIC"

    def _has_equality_violation(vec, ideal_vec, ratio_tol=1e-5, val_tol=1e-4) -> bool:
        """
        vec, ideal_vec: 1D tensors after masking (abs(ideal) > 1.0)
        Violation if exists i<j s.t. ideal_i/ideal_j ~ 1 but vec_i != vec_j.
        """
        n = int(vec.numel())
        if n <= 1:
            return False
        for i in range(n - 1):
            for j in range(i + 1, n):
                if (ideal_vec[i] / ideal_vec[j] - 1.0).abs() < ratio_tol:
                    if (vec[i] - vec[j]).abs() > val_tol:
                        return True
        return False

    def evaluate_system(name, labels_list, preds_list, ideals_list):
        """Print symmetry-violation counts and mean Frobenius error for one bucket."""
        n = len(labels_list)
        print(f"total number of {name} systems", n)
        if n == 0:
            print(f"{name} systems: (no samples)")
            return

        label_sym_error = 0
        label_equi_error = 0
        pred_sym_error = 0
        pred_equi_error = 0
        F_error = 0.0

        for L, P, I in zip(labels_list, preds_list, ideals_list):
            label = L.view(-1)
            pred = P.view(-1)
            ideal = I.view(-1)

            F_error += torch.norm(label - pred, p=2).item()

            # Zero-entry check: entries with abs(ideal) < 1.0 should be zero.
            zero_mask = ideal.abs() < 1.0
            if (label[zero_mask].abs() > 1e-5).any().item():
                label_sym_error += 1
            if (pred[zero_mask].abs() > 1e-5).any().item():
                pred_sym_error += 1

            # Equality check among entries with abs(ideal) > 1.0.
            eq_mask = ideal.abs() > 1.0
            ideal_eq = ideal[eq_mask]
            if _has_equality_violation(label[eq_mask], ideal_eq):
                label_equi_error += 1
            if _has_equality_violation(pred[eq_mask], ideal_eq):
                pred_equi_error += 1

        print(
            f"{name} systems: Label symmetry error - Zero Error",
            label_sym_error,
            "Equality Error",
            label_equi_error,
        )
        print(
            "Prediction error - Zero Error",
            pred_sym_error,
            "Equality Error",
            pred_equi_error,
            "Fnorm",
            F_error / n,
        )

    model.to(device)
    model.eval()

    # (bucket key, display name) in the order the reports are printed.
    report_order = (
        ("CUBIC", "CUBIC"),
        ("TETR", "Tetragonal"),
        ("HEXA", "Hexagonal"),
        ("ORTH", "Orthorhombic"),
        ("MONO", "Monoclinic"),
        ("TRIC", "Triclinic"),
    )
    buckets = {key: {"labels": [], "preds": [], "ideals": []} for key, _ in report_order}

    mae_list = []
    frob_list = []
    percen_list = []
    forced_zero_counts = []
    use_zero_mask = getattr(args, "zero_mask", False)

    with torch.no_grad():
        for idx, batch in enumerate(tqdm(test_loader)):
            structure, mask, equality, labels, _rot_list = batch
            structure = structure.to(device)
            mask = mask.to(device)
            equality = equality.to(device)
            labels = labels.to(device)

            outputs = model(structure, mask, equality).detach().cpu().view(3, 3)
            labels_cpu = labels.detach().cpu().view(3, 3)

            sample = dataset_test[idx]
            # Converted once and reused by the zero mask and the per-system buckets.
            ideal_mat = _to_tensor(sample["ideal_matrix"]).detach().cpu().view(3, 3)

            if use_zero_mask:
                cart_rots = get_cartesian_rotations_from_sym_dataset(
                    sample["structure"], sample.get("sym_dataset", None)
                )
                forced_zero_mask_t = torch.as_tensor(
                    infer_forced_zero_mask_symmetric_rank2(cart_rots), dtype=torch.bool
                )
                # Union with the ideal-matrix zeros: if the inferred mask is empty,
                # the ideal mask still enforces the symmetry zeros.
                combined_mask = forced_zero_mask_t | (ideal_mat.abs() < 1.0)
                outputs = outputs.clone()
                outputs[combined_mask] = 0.0
                forced_zero_counts.append(int(combined_mask.sum().item()))

            mae_list.append((outputs - labels_cpu).abs().mean().item())
            frob_ = torch.norm(labels_cpu.view(-1) - outputs.view(-1), p=2).item()
            frob_norm = torch.norm(labels_cpu.view(-1), p=2).item()
            frob_list.append(frob_)
            percen_list.append(frob_ / frob_norm if frob_norm > 0 else float("nan"))

            key = _bucket_name(int(sample["sym_dataset"]["number"]))
            buckets[key]["labels"].append(labels_cpu)
            buckets[key]["preds"].append(outputs)
            buckets[key]["ideals"].append(ideal_mat)

    # No eT metric is computed in this script; the line is kept so the log
    # format stays unchanged.
    print("diff in eT", 0.0)

    print("MAE ", float(np.mean(mae_list)) if mae_list else 0.0)
    print("M_Frob", float(np.mean(frob_list)) if frob_list else 0.0)

    percen_arr = np.asarray(percen_list, dtype=np.float64)
    percen_arr = percen_arr[np.isfinite(percen_arr)]
    for label, thresh in (("EwT 25", 0.25), ("EwT 10", 0.10), ("EwT 5", 0.05), ("EwT 2", 0.02)):
        print(label, float(np.mean(percen_arr < thresh)) if percen_arr.size else 0.0)

    if use_zero_mask and forced_zero_counts:
        print("mean forced-zero entries", float(np.mean(forced_zero_counts)))

    for key, display_name in report_order:
        bucket = buckets[key]
        evaluate_system(display_name, bucket["labels"], bucket["preds"], bucket["ideals"])
def str2bool(value):
    """Parse a command-line boolean value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, including ``"False"``. This parser maps the usual
    spellings explicitly instead.

    Args:
        value: raw argument string; a ``bool`` is returned unchanged.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse arguments, load the CEITNet checkpoint, and run ``test``."""
    parser = argparse.ArgumentParser(description='Test script')

    # Define command-line arguments
    parser.add_argument('--model', type=str, default='ceitnet', help='ceitnet')
    parser.add_argument('--ckpt_path', type=str, default='repro_data/pretrained_ckpts/diele.pt', help='path to model checkpoint')
    parser.add_argument('--reduce_cell', type=str2bool, default=False, help='reduce the cell into irreducible atom sets, not used')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='dielectric', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--use_corrected_structure', type=str2bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=str2bool, default=True, help='load previous processed dataset')
    parser.add_argument('--zero_mask', action='store_true', help='force zero entries during evaluation')

    args = parser.parse_args()

    print('Test settings:')
    print(args)
    # Reseed with the split seed (overrides the module-level SEED for torch).
    torch.manual_seed(args.split_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.split_seed)

    # load the model
    model = CEITNet(args)
    state_dict = torch.load(args.ckpt_path, map_location=device)
    # Load the state dictionary into the model
    model.load_state_dict(state_dict)

    test(model, args)


if __name__ == "__main__":
    main()
