import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset
import pandas as pd
import pickle as pk
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from jarvis.core.atoms import Atoms
from torch.utils.data import DataLoader
from tqdm import tqdm
from e3nn.io import CartesianTensor
from pandarallel import pandarallel
from data import get_symmetry_dataset, rm_duplicates
pandarallel.initialize(progress_bar=False)
import gc
from graphs import atoms2graphs, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet
import matplotlib.pyplot as plt
from e3nn import o3
import pdb
# torch config
torch.set_default_dtype(torch.float32)
import torch
import numpy as np
import random
import os

# Seed the Python, NumPy, and PyTorch random number generators so runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU, but fall back to the CPU so the script can still run on machines
# without CUDA (a hard-coded "cuda" device only fails later, at model.to(device)).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The CUDA seeding calls are kept unconditionally; PyTorch documents them as
# silently ignored when CUDA is unavailable.
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
# Configure cuDNN to use deterministic algorithms and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared adaptor instance; structure_to_graphs() calls its get_atoms() on each
# stored structure before building a graph.
adptor = JarvisAtomsAdaptor()

# NOTE(review): presumably the flat (row-major) indices of the diagonal and
# off-diagonal entries of a 3x3 matrix; neither list is referenced in the
# visible part of this file -- confirm whether they are still needed.
diagonal = [0, 4, 8]
off_diagonal = [1, 2, 3, 5, 6, 7]

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 16
):
    """Build one graph per row of ``df["p_input"]`` and return them as an array.

    Each ``p_input`` entry is expected to provide ``"structure"`` and
    ``"equivalent_atoms"``. The structure is converted with the module-level
    adaptor and handed to ``atoms2graphs`` together with ``cutoff``,
    ``max_neighbors`` and ``reduce_cell``. The conversion runs through
    pandarallel's ``parallel_apply``.

    ``use_corrected_structure`` is accepted for interface compatibility but is
    not read by this function.
    """
    def _build_graph(entry):
        """Convert a single ``p_input`` record into its graph representation."""
        return atoms2graphs(
            adptor.get_atoms(entry["structure"]),
            cutoff=cutoff,
            max_neighbors=max_neighbors,
            reduce=reduce_cell,
            equivalent_atoms=entry['equivalent_atoms'],
            use_canonize=True,
        )

    converted = df["p_input"].parallel_apply(_build_graph)
    return converted.values


def _sym_dataset_rotations(sym_dataset):
    """Return fractional-coordinate rotation matrices from a spglib symmetry dataset."""
    if hasattr(sym_dataset, "rotations"):
        return np.array(sym_dataset.rotations)
    return np.array(sym_dataset["rotations"])


def get_cartesian_rotations_from_sym_dataset(structure: Structure, sym_dataset) -> np.ndarray:
    """Return the de-duplicated rotations of ``sym_dataset`` conjugated into the Cartesian frame.

    The rotations are read from the dataset, passed through ``rm_duplicates``,
    and each rotation ``W`` is mapped to ``A @ W @ inv(A)``, where ``A`` is the
    transpose of ``structure.lattice.matrix``.
    """
    unique_rots = rm_duplicates(_sym_dataset_rotations(sym_dataset))
    basis = structure.lattice.matrix.T
    basis_inv = np.linalg.inv(basis)
    # matmul broadcasts over the leading (operation) axis of unique_rots.
    return np.matmul(basis, np.matmul(unique_rots, basis_inv))


def _transform_rank4(R: np.ndarray, T: np.ndarray) -> np.ndarray:
    """Apply a 3x3 linear transform to a rank-4 Cartesian tensor."""
    return np.einsum('ia,jb,kc,ld,abcd->ijkl', R, R, R, R, T)


def _voigt_pairs_ceitnet():
    # Must match ceitnet VoigtBlock convention
    return [(0, 0), (1, 1), (2, 2), (0, 1), (1, 2), (2, 0)]


def _voigt66_to_rank4(C66: np.ndarray) -> np.ndarray:
    """Scatter a 6x6 Voigt matrix into a (3, 3, 3, 3) float64 tensor.

    Every entry ``C66[I, J]`` whose magnitude is at least 1e-15 is written to all
    index permutations implied by the minor symmetries (swap within each index
    pair) and the major symmetry (swap the two pairs). Entries are processed in
    row-major order, so for an asymmetric input the later entry wins.
    """
    index_pairs = _voigt_pairs_ceitnet()
    tensor = np.zeros((3, 3, 3, 3), dtype=np.float64)
    for row, left in enumerate(index_pairs):
        left_orders = (left, left[::-1])
        for col, right in enumerate(index_pairs):
            value = float(C66[row, col])
            if abs(value) < 1e-15:
                continue
            right_orders = (right, right[::-1])
            for p, q in left_orders:
                for r, s in right_orders:
                    tensor[p, q, r, s] = value
                    tensor[r, s, p, q] = value
    return tensor


def _forced_zero_rank4_to_voigt66(forced_zero_3333: np.ndarray) -> np.ndarray:
    """Collapse a (3, 3, 3, 3) forced-zero mask to a (6, 6) boolean Voigt mask.

    Entry ``[I, J]`` of the result is the truth value of the rank-4 mask at the
    index pairs that ``_voigt_pairs_ceitnet`` assigns to ``I`` and ``J``.
    """
    index_pairs = _voigt_pairs_ceitnet()
    rows = [
        [bool(forced_zero_3333[a, b, c, d]) for (c, d) in index_pairs]
        for (a, b) in index_pairs
    ]
    return np.array(rows, dtype=bool)


def infer_forced_zero_mask_symmetric_rank4_voigt(
    cart_rots: np.ndarray,
    rcond: float = 1e-10,
    zero_tol: float = 1e-8,
) -> np.ndarray:
    """Find the Voigt (6, 6) entries that the given symmetry operations force to zero.

    Procedure: take the 21 unit symmetric 6x6 Voigt matrices, expand each to a
    rank-4 tensor, average its images under all operations in ``cart_rots``, and
    stack the flattened averages as columns of an (81, 21) matrix. The left
    singular vectors with singular value above ``rcond * S[0]`` span the averaged
    tensors; a component is reported as forced zero when its magnitude is below
    ``zero_tol`` in every one of those vectors.

    Args:
        cart_rots: sequence of 3x3 operations; ``None`` or empty yields an
            all-False mask.
        rcond: relative singular-value cutoff.
        zero_tol: absolute tolerance for treating a component as zero.

    Returns:
        forced_zero_mask_66: (6,6) bool array.
    """
    if cart_rots is None or len(cart_rots) == 0:
        return np.zeros((6, 6), dtype=bool)

    def unit_voigt_tensor(row, col):
        """Rank-4 expansion of the symmetric unit Voigt matrix for (row, col)."""
        voigt = np.zeros((6, 6), dtype=np.float64)
        voigt[row, col] = 1.0
        voigt[col, row] = 1.0
        return _voigt66_to_rank4(voigt)

    basis_tensors = [
        unit_voigt_tensor(row, col) for row in range(6) for col in range(row, 6)
    ]

    num_ops = float(len(cart_rots))
    averaged_columns = []
    for basis_tensor in basis_tensors:
        total = np.zeros((3, 3, 3, 3), dtype=np.float64)
        for rotation in cart_rots:
            total += _transform_rank4(rotation, basis_tensor)
        total /= num_ops
        averaged_columns.append(total.reshape(-1))  # 81 components

    stacked = np.stack(averaged_columns, axis=1)  # (81, 21)
    left_vectors, singular_values, _ = np.linalg.svd(stacked, full_matrices=False)
    if singular_values.size == 0:
        return np.zeros((6, 6), dtype=bool)

    significant = singular_values > (rcond * singular_values[0])
    if not np.any(significant):
        return np.zeros((6, 6), dtype=bool)

    invariant_basis = left_vectors[:, significant]  # (81, k)
    vanishing = np.max(np.abs(invariant_basis), axis=1) < zero_tol
    return _forced_zero_rank4_to_voigt66(vanishing.reshape(3, 3, 3, 3))

# def structure_to_graphs(
#     df: pd.DataFrame,
#     use_corrected_structure: bool = False,
#     reduce_cell: bool = False,
#     cutoff: float = 6.0,
#     max_neighbors: int = 16
# ):
#     def atoms_to_graph(p_input):
#         """Convert structure dict to DGLGraph."""
#         structure = adptor.get_atoms(p_input["structure"])
#         return atoms2graphs_etgnn(
#             structure,
#             cutoff=cutoff,
#         )
#     graphs = df["p_input"].parallel_apply(atoms_to_graph).values
#     # graphs = df["p_input"].apply(atoms_to_graph).values
#     return graphs

class PolynomialLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Polynomial learning-rate decay from ``start_lr`` to ``end_lr``.

    The rate after ``t`` calls to :meth:`step` is
    ``(start_lr - end_lr) * (1 - t / max_iters) ** power + end_lr``, with ``t``
    clamped to ``max_iters`` so the rate stays at ``end_lr`` once the schedule
    is exhausted. Without the clamp, stepping past ``max_iters`` made the base
    ``1 - t / max_iters`` negative: the rate dropped below ``end_lr`` for
    ``power=1`` and became complex for fractional powers.

    Every parameter group receives the same rate; the optimizer's own base
    learning rates are not used.

    Args:
        optimizer: wrapped optimizer.
        max_iters: number of steps over which the rate decays.
        start_lr: rate at step 0.
        end_lr: rate at and after ``max_iters`` steps.
        power: exponent of the polynomial.
        last_epoch: forwarded to the base scheduler.
    """

    def __init__(self, optimizer, max_iters, start_lr, end_lr, power=1, last_epoch=-1):
        self.max_iters = max_iters
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.power = power
        # Own step counter; must exist before the base constructor runs because
        # the base class may call step() (and thus get_lr()) during construction.
        self.last_iter = 0
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the current rate, repeated once per parameter group."""
        # Clamp so that steps beyond max_iters hold the rate at end_lr.
        progress = min(self.last_iter, self.max_iters) / self.max_iters
        lr = (self.start_lr - self.end_lr) * ((1 - progress) ** self.power) + self.end_lr
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        """Advance the schedule by one iteration and update the optimizer."""
        self.last_iter += 1  # Increment the last iteration count
        return super().step(epoch)

def group_decay(model):
    """Split model parameters into optimizer groups with and without weight decay.

    A parameter is exempt from weight decay when its name contains "bias", "bn"
    or "norm". Returns two optimizer parameter groups: the first uses the
    optimizer's default weight decay, the second sets ``weight_decay`` to 0.
    """
    exempt_markers = ("bias", "bn", "norm")
    regularized = []
    exempt = []
    for param_name, param in model.named_parameters():
        is_exempt = any(marker in param_name for marker in exempt_markers)
        target = exempt if is_exempt else regularized
        target.append(param)

    return [
        dict(params=regularized),
        dict(params=exempt, weight_decay=0),
    ]


def get_pyg_dataset(data, target, reduce_cell=False):
    """Wrap raw records into a ``GraphDataset`` for the given ``target``.

    The records are loaded into a DataFrame, converted to graphs with
    ``structure_to_graphs`` (forwarding ``reduce_cell``), and both are handed to
    ``GraphDataset``.
    """
    frame = pd.DataFrame(data)
    graphs = structure_to_graphs(frame, reduce_cell=reduce_cell)
    return GraphDataset(df=frame, graphs=graphs, target=target)


# Layout of the per-crystal-system report, in print order. Each entry is
# (bucket key, caption used in the "total number of ..." line, caption used in the
# header of the label line, caption printed in front of the equality counters).
# The captions reproduce the script's established output verbatim.
_CRYSTAL_SYSTEM_REPORTS = (
    ("cubic", "cubic", "CUBIC", "Equality Error"),
    ("tetragonal", "Tetragonal", "Tetragonal", "Equality Error"),
    ("hexagonal", "hexagonal", "Hexagonal", "Equality Error"),
    ("trigonal", "trigonal", "Trigonal", "Equality Error"),
    ("orthorhombic", "Orthorhombic", "Orthorhombic", "Inequality Error"),
    ("monoclinic", "Monoclinic", "Monoclinic", "Equality Error"),
    ("triclinic", "Triclinic", "Triclinic", "Equality Error"),
)


def _crystal_system_from_space_group(space_g):
    """Map a space-group number to the bucket key of its crystal system."""
    if space_g >= 195:   # 195 <= number
        return "cubic"
    if space_g >= 168:   # 168 <= number <= 194
        return "hexagonal"
    if space_g >= 143:   # 143 <= number <= 167
        return "trigonal"
    if space_g >= 75:    # 75 <= number <= 142
        return "tetragonal"
    if space_g >= 16:    # 16 <= number <= 74
        return "orthorhombic"
    if space_g >= 3:     # 3 <= number <= 15
        return "monoclinic"
    return "triclinic"   # number <= 2


def _ideal_pair_relations(ideal):
    """Return boolean (n, n) masks of index pairs px < py with equal / opposite ideal values.

    ``ideal`` is a 1-D tensor of non-zero reference entries. A pair counts as
    equal when ``|ideal[px] / ideal[py] - 1| < 1e-5`` and as opposite when
    ``|ideal[px] + ideal[py]| < 1e-5 * |ideal[px]|``.
    """
    index = torch.arange(ideal.shape[0])
    # Only px < py is needed: a pair of an entry with itself can never be violated.
    upper = index[:, None] < index[None, :]
    same = (abs(ideal[:, None] / ideal[None, :] - 1.0) < 1e-5) & upper
    opposite = (abs(ideal[:, None] + ideal[None, :]) < 1e-5 * abs(ideal)[:, None]) & upper
    return same, opposite


def _breaks_pair_relations(values, same, opposite):
    """Return True if ``values`` differs by more than 1e-4 on any related pair."""
    differs = abs(values[:, None] - values[None, :]) > 1e-4
    not_cancelling = abs(values[:, None] + values[None, :]) > 1e-4
    return bool((same & differs).any() or (opposite & not_cancelling).any())


def _report_crystal_system(samples, count_caption, header_caption, equality_caption):
    """Print symmetry-consistency statistics for the samples of one crystal system.

    ``samples`` is a list of ``(label, prediction, ideal)`` tensors. For labels
    and predictions separately, the report counts the samples that
      * are non-zero (> 1e-5) where the ideal matrix is zero (|ideal| < 1.0), and
      * break an equality or sign relation between non-zero ideal entries
        (|ideal| > 1.0),
    and prints the mean Frobenius norm of ``label - prediction``.
    """
    print(f"total number of {count_caption} systems", len(samples))
    label_sym_error = 0
    label_equi_error = 0
    F_error = 0
    pred_sym_error = 0
    pred_equi_error = 0

    for label, pred, ideal in samples:
        label = label.view(-1)
        pred = pred.view(-1)
        ideal = ideal.view(-1)

        F_error += ((label - pred) ** 2).sum() ** 0.5

        # Zero-pattern check.
        zero_entries = abs(ideal) < 1.0
        if (abs(label[zero_entries]) > 1e-5).any():
            label_sym_error += 1
        if (abs(pred[zero_entries]) > 1e-5).any():
            pred_sym_error += 1

        # Equality / sign-relation check on the non-zero entries.
        nonzero_entries = abs(ideal) > 1.0
        same, opposite = _ideal_pair_relations(ideal[nonzero_entries])
        if _breaks_pair_relations(label[nonzero_entries], same, opposite):
            label_equi_error += 1
        if _breaks_pair_relations(pred[nonzero_entries], same, opposite):
            pred_equi_error += 1

    # An empty bucket has no mean; report nan instead of dividing by zero.
    mean_fnorm = F_error / len(samples) if samples else float("nan")

    print(f"{header_caption} systems: Label symmetry error - Zero Error", label_sym_error, equality_caption, label_equi_error)
    print("Prediction error - Zero Error", pred_sym_error, equality_caption, pred_equi_error, "Fnorm", mean_fnorm)


def test(model, args):
    """Evaluate ``model`` on the test split and print accuracy and symmetry statistics.

    Steps:
      1. Load the dataset named by ``args.target`` and split it with
         ``get_id_train_val_test`` using the ratios and seed in ``args``.
      2. Print the atomic numbers (0-119) that never occur in the training split.
      3. Run the model sample by sample (batch size 1) on the test split,
         optionally zeroing ceitnet predictions via ``args.ceitnet_strict_zero_mask``
         (symmetry-derived mask) and/or ``args.zero_mask`` (ideal-matrix mask).
      4. Print MAE, mean Frobenius error and the fraction of samples whose
         relative Frobenius error is below 25/10/5/2 percent.
      5. Print a symmetry-consistency report per crystal system.

    Args:
        model: network to evaluate; moved to the module-level ``device``.
        args: parsed command-line namespace (see ``main``).

    Returns:
        None. All results are printed.
    """
    # load the dataset
    if args.load_preprocessed:
        print("load preprocessed dataset ...")
    dataset_sym = get_dataset(
        dataset_name=args.target,
        use_corrected_structure=args.use_corrected_structure,
        load_preprocessed=args.load_preprocessed,
    )
    id_train, id_val, id_test = get_id_train_val_test(
        total_size=len(dataset_sym),
        split_seed=args.split_seed,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        keep_data_order=False,
    )

    # Elements that the model never saw during training.
    seen_elements = {
        int(z) for x in id_train for z in dataset_sym[x]['structure'].atomic_numbers
    }
    unseen_list = [z for z in range(120) if z not in seen_elements]
    print("unseen elements:", unseen_list)

    dataset_test = [dataset_sym[x] for x in id_test]
    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target)

    # form dataloaders; batch_size=1 and shuffle=False keep loader position i
    # aligned with dataset_test[i], which the loop below relies on.
    test_loader = DataLoader(
        pyg_dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=pyg_dataset_test.collate,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    print("n_test:", len(test_loader.dataset))

    model.to(device)
    model.eval()

    # (label, prediction, ideal) triples grouped by crystal system.
    per_system = {key: [] for key, _, _, _ in _CRYSTAL_SYSTEM_REPORTS}
    mae_list = []
    frob_list = []
    percen_list = []

    for i, data in enumerate(tqdm(test_loader)):
        sample = dataset_test[i]
        structure, mask, equality, labels = data
        structure, mask, equality, labels = structure.to(device), mask.to(device), equality.to(device), labels.to(device)

        if args.model in ["comformer", "ceitnet", "geoctp"]:
            outputs = model(structure, mask, equality).view(-1).cpu().detach()
        elif args.model == "megnet":
            outputs = model(structure, test=True).view(-1).cpu().detach()
        else:
            outputs = model(structure).view(-1).cpu().detach()
            if args.model == "etgnn":
                # Keep consistent with training: flatten labels to 36-dim Voigt vector.
                labels = labels.view(labels.size(0), -1)

        # Optional: Apply group-average-derived strict zero mask to ceitnet predictions BEFORE evaluation.
        if args.model == "ceitnet" and getattr(args, "ceitnet_strict_zero_mask", False):
            try:
                cart_rots = get_cartesian_rotations_from_sym_dataset(
                    sample["structure"], sample["sym_dataset"]
                )
                forced_zero_66 = infer_forced_zero_mask_symmetric_rank4_voigt(
                    cart_rots,
                    rcond=getattr(args, "ceitnet_zero_mask_rcond", 1e-10),
                    zero_tol=getattr(args, "ceitnet_zero_mask_tol", 1e-8),
                )
                forced_zero_t = torch.tensor(forced_zero_66, dtype=torch.bool)
                out_mat = outputs.view(6, 6).clone()
                out_mat[forced_zero_t] = 0.0
                outputs = out_mat.reshape(-1)
            except Exception as err:
                # Best effort: keep the unmasked prediction, but make the failure visible
                # so silently unmasked samples do not distort the reported statistics.
                tqdm.write(f"warning: strict zero mask skipped for test sample {i}: {err!r}")

        # Optional: Force ceitnet predictions to zero on entries marked as zero by ideal_matrix (matches Zero Error check).
        if args.model == "ceitnet" and getattr(args, "zero_mask", False):
            ideal = torch.tensor(sample["ideal_matrix"])
            if ideal.numel() == outputs.numel():
                zm = ideal.view(-1).abs() < 1.0
                outputs = outputs.clone()
                outputs[zm] = 0.0

        labels = labels.cpu().view(-1)
        mae_list.append(abs(outputs - labels).view(-1).mean())
        frob_ = ((labels.view(-1) - outputs.view(-1)) ** 2).sum() ** 0.5
        frob_norm = (labels.view(-1) ** 2).sum() ** 0.5
        frob_list.append(frob_)
        percen_list.append(frob_ / (frob_norm + 1e-5))

        # Accept both attribute-style and mapping-style symmetry datasets,
        # mirroring _sym_dataset_rotations.
        sym_dataset = sample['sym_dataset']
        space_g = sym_dataset.number if hasattr(sym_dataset, "number") else sym_dataset["number"]
        per_system[_crystal_system_from_space_group(space_g)].append(
            (labels, outputs, torch.tensor(sample['ideal_matrix']))
        )

    print("MAE ", np.mean(mae_list))
    print("M_Frob", np.mean(frob_list))
    percen_list = np.array(percen_list)
    print("EwT 25", np.sum(percen_list < 0.25) / percen_list.shape[0])
    print("EwT 10", np.sum(percen_list < 0.1) / percen_list.shape[0])
    print("EwT 5", np.sum(percen_list < 0.05) / percen_list.shape[0])
    print("EwT 2", np.sum(percen_list < 0.02) / percen_list.shape[0])

    for key, count_caption, header_caption, equality_caption in _CRYSTAL_SYSTEM_REPORTS:
        _report_crystal_system(per_system[key], count_caption, header_caption, equality_caption)

    return



def main():
    parser = argparse.ArgumentParser(description='Training script')

    # Define command-line arguments
    # training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of training and evaluating')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-05, help='weight decay')
    parser.add_argument('--loss', type=str, default='huber', help='mse or l1 or huber')
    parser.add_argument('--model', type=str, default='ceitnet', help='ceitnet')
    parser.add_argument('--name', type=str, default='test', help='name of project for storage')
    parser.add_argument('--reduce_cell', type=bool, default=False, help='reduce the cell into irreducible atom sets')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='elastic', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--threshold', type=float, default=100., help='threshold to remove samples')
    parser.add_argument('--use_corrected_structure', type=bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=bool, default=True, help='load previous processed dataset')
    parser.add_argument('--ckpt_path', type=str, default='repro_data/pretrained_ckpts/elast.pt', help='Path to a checkpoint .pt file (overrides ckpt_kind)')
    parser.add_argument('--zero_mask', action='store_true', help='')



    args = parser.parse_args()

    print('Training settings:')
    print(f'  Epochs: {args.epochs}')
    print(f'  Learning rate: {args.learning_rate}')
    print(args)
    torch.manual_seed(args.split_seed)
    torch.cuda.manual_seed_all(args.split_seed)
    # load the model
    if args.model == "ceitnet":
        model = CEITNet(args)

    state_dict = torch.load(args.ckpt_path)

    model.load_state_dict(state_dict)

    test(model, args)

if __name__ == "__main__":
    main()