import argparse
import torch
from torch import nn
import numpy as np
from data import get_dataset
import pandas as pd
import pickle as pk
from pymatgen.io.jarvis import JarvisAtomsAdaptor
from pymatgen.core.lattice import Lattice
from pymatgen.core.structure import Structure
from jarvis.core.atoms import Atoms
from torch.utils.data import DataLoader
from tqdm import tqdm
from e3nn.io import CartesianTensor
from pandarallel import pandarallel
from data import get_symmetry_dataset, rm_duplicates
pandarallel.initialize(progress_bar=False)

from graphs import atoms2graphs, atoms2graphs_etgnn, GraphDataset
from utils import get_id_train_val_test
from ceitnet import CEITNet
import matplotlib.pyplot as plt
from e3nn import o3
import pdb
import json
from pathlib import Path
import glob
# torch config
torch.set_default_dtype(torch.float32)
import torch
import numpy as np
import random
import os

# Seed Python, NumPy and PyTorch so evaluation runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is present; otherwise fall back to the CPU instead of
# failing later at `model.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The CUDA seeding calls are silently ignored when CUDA is unavailable, so they
# are safe on CPU-only machines too.
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # if using multi-GPU.
# Make cuDNN pick deterministic kernels and disable its autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared JarvisAtomsAdaptor instance; structure_to_graphs calls its get_atoms().
adptor = JarvisAtomsAdaptor()

# Flat row-major indices of the diagonal / off-diagonal entries of a 3x3 matrix.
diagonal = [0, 4, 8]
off_diagonal = [1, 2, 3, 5, 6, 7]

def structure_to_graphs(
    df: pd.DataFrame,
    use_corrected_structure: bool = False,
    reduce_cell: bool = False,
    cutoff: float = 4.0,
    max_neighbors: int = 64
):
    """Build one graph per row of ``df`` from its ``p_input`` column.

    Each ``p_input`` entry is read as a mapping with a ``"structure"`` (converted
    with the module-level ``adptor.get_atoms``) and ``"equivalent_atoms"``; both
    are forwarded to ``atoms2graphs`` together with ``cutoff``,
    ``max_neighbors``, ``reduce_cell`` (as ``reduce``) and ``use_canonize=True``.
    Rows are processed with pandarallel's ``parallel_apply``.

    Note: ``use_corrected_structure`` is accepted but not used here.

    Returns:
        The ``.values`` array of the per-row results of ``atoms2graphs``.
    """
    shared_kwargs = dict(
        cutoff=cutoff,
        max_neighbors=max_neighbors,
        reduce=reduce_cell,
        use_canonize=True,
    )

    def _row_to_graph(row_input):
        atoms = adptor.get_atoms(row_input["structure"])
        return atoms2graphs(
            atoms,
            equivalent_atoms=row_input["equivalent_atoms"],
            **shared_kwargs,
        )

    return df["p_input"].parallel_apply(_row_to_graph).values

def _sym_dataset_rotations(sym_dataset):
    """Return fractional-coordinate rotation matrices from a spglib symmetry dataset."""
    # spglib can return either a dict-like or an object with attributes
    if hasattr(sym_dataset, "rotations"):
        return np.array(sym_dataset.rotations)
    return np.array(sym_dataset["rotations"])


def get_cartesian_rotations_from_sym_dataset(structure: Structure, sym_dataset) -> np.ndarray:
    """Express the symmetry dataset's rotations in Cartesian coordinates.

    The fractional rotations from ``sym_dataset`` are passed through
    ``rm_duplicates`` and then conjugated with the matrix ``A`` whose columns
    are ``structure.lattice.matrix`` transposed:
    ``R_cart = A @ (R_frac @ inv(A))``.
    """
    frac_rotations = rm_duplicates(_sym_dataset_rotations(sym_dataset))
    to_cart = structure.lattice.matrix.T
    to_frac = np.linalg.inv(to_cart)
    return np.matmul(to_cart, np.matmul(frac_rotations, to_frac))


def _transform_rank3(R: np.ndarray, T: np.ndarray) -> np.ndarray:
    return np.einsum('ia,jb,kc,abc->ijk', R, R, R, T)


def infer_forced_zero_mask_symmetric_rank3_jk(
    cart_rots: np.ndarray,
    rcond: float = 1e-10,
    zero_tol: float = 1e-8,
) -> np.ndarray:
    """Find the components of a (j,k)-symmetric rank-3 tensor forced to zero by symmetry.

    Each basis tensor of the space of rank-3 tensors symmetric in their last two
    indices is averaged over ``cart_rots`` (applied with ``_transform_rank3``).
    The span of those averages is the subspace of tensors invariant under the
    rotations. An orthonormal basis of it is taken from an SVD, and a component
    is reported as forced to zero when every basis vector vanishes there.

    Args:
        cart_rots: sequence of 3x3 Cartesian rotation matrices; presumably a
            complete point group without duplicates (the averaging relies on it).
        rcond: singular values not above ``rcond * max(S_max, 1)`` are treated
            as zero.
        zero_tol: a component whose magnitude is below this in every invariant
            basis vector is reported as forced to zero.

    Returns:
        Boolean (3, 3, 3) array, True where the component is forced to zero.
        All False when ``cart_rots`` is empty or None (no constraint known).
        All True when no nonzero invariant tensor exists, e.g. for groups
        containing the inversion, which maps every rank-3 tensor T to -T.
    """
    if cart_rots is None or len(cart_rots) == 0:
        return np.zeros((3, 3, 3), dtype=bool)

    # Basis of tensors symmetric in (j, k): i in 0..2 and j <= k, with the
    # (k, j) entry mirrored. That gives 18 tensors.
    basis = []
    for i in range(3):
        for j in range(3):
            for k in range(j, 3):
                B = np.zeros((3, 3, 3), dtype=np.float64)
                B[i, j, k] = 1.0
                B[i, k, j] = 1.0  # same entry when j == k
                basis.append(B)

    # Group average (Reynolds projection) of every basis tensor.
    n = float(len(cart_rots))
    cols = []
    for B in basis:
        avg = np.zeros((3, 3, 3), dtype=np.float64)
        for R in cart_rots:
            avg += _transform_rank3(R, B)
        avg /= n
        cols.append(avg.reshape(-1))

    M = np.stack(cols, axis=1)  # (27, 18); column space = invariant tensors
    U, S, _ = np.linalg.svd(M, full_matrices=False)
    if S.size == 0:
        return np.zeros((3, 3, 3), dtype=bool)

    # For a finite group of orthogonal matrices the averaging is an orthogonal
    # projection and the basis columns have norm 1 or sqrt(2), so every nonzero
    # singular value of M is >= 1. Flooring the reference at 1.0 makes a
    # numerically-zero M rank 0, instead of keeping its round-off noise (or
    # comparing exact zeros against a zero threshold).
    tol = rcond * max(float(S[0]), 1.0)
    keep = S > tol
    if not np.any(keep):
        # No nonzero invariant tensor exists: every component is forced to zero.
        return np.ones((3, 3, 3), dtype=bool)

    inv_basis = U[:, keep]  # orthonormal basis of the invariant subspace, (27, r)
    forced_zero = np.max(np.abs(inv_basis), axis=1) < zero_tol
    return forced_zero.reshape(3, 3, 3)


def forced_zero_mask_rank3_to_voigt36(forced_zero_mask_333: np.ndarray) -> np.ndarray:
    """Collapse a (3, 3, 3) forced-zero mask onto the (3, 6) Voigt layout.

    Column J of row i holds input entry ``[i, j, k]`` for the J-th Voigt pair,
    in the order xx, yy, zz, xy, yz, zx (ceitnet's ``to_voigt`` convention).
    Only the listed (j, k) ordering is read; mirrored entries ``[i, k, j]`` of
    the off-diagonal pairs are not consulted.
    """
    voigt_j = [0, 1, 2, 0, 1, 2]  # xx, yy, zz, xy, yz, zx
    voigt_k = [0, 1, 2, 1, 2, 0]
    as_bool = np.asarray(forced_zero_mask_333, dtype=bool)
    return as_bool[:3, voigt_j, voigt_k]

class PolynomialLRDecay(torch.optim.lr_scheduler._LRScheduler):
    """Polynomial learning-rate decay from ``start_lr`` to ``end_lr``.

    After ``t`` user calls to :meth:`step`, every parameter group gets::

        (start_lr - end_lr) * (1 - min(t, max_iters) / max_iters) ** power + end_lr

    so the schedule starts exactly at ``start_lr``, reaches ``end_lr`` after
    ``max_iters`` steps and stays there afterwards. A non-positive ``max_iters``
    yields ``end_lr`` immediately.

    Note: all parameter groups receive the same schedule; the optimizer's
    per-group initial learning rates (``base_lrs``) are not used.
    """

    def __init__(self, optimizer, max_iters, start_lr, end_lr, power=1, last_epoch=-1):
        self.max_iters = max_iters
        self.start_lr = start_lr
        self.end_lr = end_lr
        self.power = power
        # Number of user step() calls. The base-class constructor performs one
        # initial step() itself, so start at -1: that step lands on iteration 0
        # (lr == start_lr) instead of skipping it.
        self.last_iter = -1
        super().__init__(optimizer, last_epoch)

    def _decay_factor(self):
        """Return ``(1 - progress) ** power`` with progress clamped to [0, 1]."""
        if self.max_iters <= 0:
            return 0.0
        # Clamping keeps the base non-negative; a negative base with a
        # fractional power would produce a complex learning rate.
        progress = min(max(self.last_iter, 0), self.max_iters) / self.max_iters
        return (1.0 - progress) ** self.power

    def get_lr(self):
        lr = (self.start_lr - self.end_lr) * self._decay_factor() + self.end_lr
        return [lr for _ in self.base_lrs]

    def step(self, epoch=None):
        self.last_iter += 1  # Increment the iteration count before recomputing lr
        return super().step(epoch)

def get_pyg_dataset(data, target, reduce_cell=False):
    """Wrap raw records into a ``GraphDataset`` for ``target``.

    ``data`` is turned into a DataFrame, its rows are converted to graphs with
    ``structure_to_graphs`` (forwarding ``reduce_cell``), and the frame, graphs
    and target are handed to ``GraphDataset``.
    """
    frame = pd.DataFrame(data)
    return GraphDataset(
        df=frame,
        graphs=structure_to_graphs(frame, reduce_cell=reduce_cell),
        target=target,
    )

# Crystal-system bucket keys, checked from the highest space-group threshold
# down; any space group below 3 falls through to the triclinic bucket "tric".
_CRYSTAL_SYSTEM_THRESHOLDS = (
    (195, "cubic"),  # 195 <= sg <= 230
    (168, "hexa"),   # 168 <= sg <= 194
    (143, "trig"),   # 143 <= sg <= 167
    (75, "tetr"),    # 75 <= sg <= 142
    (16, "orth"),    # 16 <= sg <= 74
    (3, "mono"),     # 3 <= sg <= 15
)

# Order and wording of the per-system symmetry report:
# (bucket key, name in the "total number of ..." line, name in the summary
#  line, label of the relation-violation counter, check sign-flipped pairs?).
_CRYSTAL_SYSTEM_REPORTS = (
    ("cubic", "cubic", "CUBIC", "Equality Error", False),
    ("tetr", "Tetragonal", "Tetragonal", "Equality Error", True),
    ("hexa", "hexagonal", "Hexagonal", "Equality Error", True),
    ("trig", "trigonal", "Trigonal", "Equality Error", True),
    ("orth", "Orthorhombic", "Orthorhombic", "Inequality Error", True),
    ("mono", "Monoclinic", "Monoclinic", "Equality Error", True),
    ("tric", "Triclinic", "Triclinic", "Equality Error", True),
)


def _crystal_system_of(space_group):
    """Map a space-group number to its crystal-system bucket key."""
    for lowest, key in _CRYSTAL_SYSTEM_THRESHOLDS:
        if space_group >= lowest:
            return key
    return "tric"


def _unseen_elements(dataset, max_z=120):
    """Return the numbers in ``range(max_z)`` that are no atomic number of any structure in ``dataset``."""
    seen = {z for item in dataset for z in item['structure'].atomic_numbers}
    return [z for z in range(max_z) if z not in seen]


def _count_relation_violations(values, ideal, opposite):
    """Count indices whose symmetry relation taken from ``ideal`` is broken in ``values``.

    Two indices are related when their ``ideal`` entries are equal
    (``opposite=False``) or sign-flipped (``opposite=True``), within a relative
    tolerance of 1e-4. For each ``px`` (except the last), partners
    ``py >= px`` are scanned. ``px`` adds one violation if some related partner
    has ``values`` that differ (or do not cancel) by more than 1e-4.
    """
    n = values.shape[0]
    violations = 0
    for px in range(n - 1):
        for py in range(px, n):
            if opposite:
                related = abs(ideal[px] + ideal[py]) < 1e-4 * abs(ideal[px])
                broken = abs(values[px] + values[py]) > 1e-4
            else:
                related = abs(ideal[px] / ideal[py] - 1.0) < 1e-4
                broken = abs(values[px] - values[py]) > 1e-4
            if related and broken:
                violations += 1
                break
    return violations


def _report_crystal_system(header_name, summary_name, relation, labels, outputs, ideals, check_opposite):
    """Print symmetry-consistency statistics for the samples of one crystal system.

    A sample counts as a zero error when any entry with ``|ideal| < 1``
    exceeds 1e-4 in magnitude. Entries with ``|ideal| > 1`` are checked for
    equal-value relations and, if ``check_opposite``, sign-flip relations via
    ``_count_relation_violations``. The mean Frobenius error is printed as
    ``nan`` for a system without samples; dividing by zero would crash.
    """
    print(f"total number of {header_name} systems", len(labels))
    label_sym_error = 0
    label_equi_error = 0
    pred_sym_error = 0
    pred_equi_error = 0
    F_error = 0
    relations = (False, True) if check_opposite else (False,)

    for label, pred, ideal in zip(labels, outputs, ideals):
        label = label.view(-1)
        pred = pred.view(-1)
        ideal = ideal.view(-1)

        F_error += ((label - pred) ** 2).sum() ** 0.5

        zero_entries = abs(ideal) < 1.0
        if (abs(label[zero_entries]) > 1e-4).any():
            label_sym_error += 1
        if (abs(pred[zero_entries]) > 1e-4).any():
            pred_sym_error += 1

        # Equality analysis, restricted to entries the ideal pattern marks nonzero.
        nonzero = abs(ideal) > 1.0
        label, pred, ideal = label[nonzero], pred[nonzero], ideal[nonzero]
        for opposite in relations:
            label_equi_error += _count_relation_violations(label, ideal, opposite)
            pred_equi_error += _count_relation_violations(pred, ideal, opposite)

    print(f"{summary_name} systems: Label symmetry error - Zero Error", label_sym_error, relation, label_equi_error)
    mean_frob = F_error / len(labels) if labels else float("nan")
    print("Prediction error - Zero Error", pred_sym_error, relation, pred_equi_error, "Fnorm", mean_frob)


def test(model, args):
    """Evaluate ``model`` on the held-out test split and print metrics.

    The dataset is loaded with ``get_dataset`` and split with
    ``get_id_train_val_test`` using the seed and ratios in ``args``. Elements
    absent from the training split are printed. Each test sample is then run
    through the model one at a time (``batch_size=1``, no shuffling, so loader
    position ``i`` corresponds to ``dataset_test[i]``). With ``args.zero_mask``,
    components that the sample's symmetry forces to zero are zeroed before
    scoring.

    Printed metrics:
      * mean absolute error;
      * mean Frobenius error;
      * fraction of samples with relative Frobenius error below 25/10/5/2 %
        ("EwT");
      * a per-crystal-system symmetry-consistency report for labels and
        predictions.
    """
    # load the dataset
    if args.load_preprocessed:
        print("load preprocessed dataset ...")
    dataset_sym = get_dataset(
        dataset_name=args.target,
        use_corrected_structure=args.use_corrected_structure,
        load_preprocessed=args.load_preprocessed,
    )

    id_train, id_val, id_test = get_id_train_val_test(
        total_size=len(dataset_sym),
        split_seed=args.split_seed,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        keep_data_order=False,
    )
    dataset_train = [dataset_sym[x] for x in id_train]
    print("unseen elements:", _unseen_elements(dataset_train))
    dataset_test = [dataset_sym[x] for x in id_test]

    pyg_dataset_test = get_pyg_dataset(dataset_test, args.target)
    test_loader = DataLoader(
        pyg_dataset_test,
        batch_size=1,
        shuffle=False,
        collate_fn=pyg_dataset_test.collate,
        drop_last=False,
        num_workers=4,
        pin_memory=True,
    )
    print("n_test:", len(test_loader.dataset))

    model.to(device)
    model.eval()

    # Per crystal system: (labels, predictions, ideal symmetry patterns).
    buckets = {report[0]: ([], [], []) for report in _CRYSTAL_SYSTEM_REPORTS}
    mae_list = []
    frob_list = []
    percen_list = []

    for i, data in enumerate(tqdm(test_loader)):
        sample = dataset_test[i]  # valid because batch_size=1 and shuffle=False
        structure, mask, equality, labels, _group = data
        structure, mask, equality, labels = structure.to(device), mask.to(device), equality.to(device), labels.to(device)
        add_feat_mask = (abs(sample['feature_mask_ori'].detach().clone()) > 1.0).float().to(device)

        if args.model == "ceitnet":
            outputs = model(structure, mask, equality, add_feat_mask).view(3, 6).cpu().detach()
        else:
            raise ValueError(f"Model {args.model} not supported")

        # Optional: zero the components the group average forces to zero, before scoring.
        if getattr(args, "zero_mask", False):
            cart_rots = get_cartesian_rotations_from_sym_dataset(sample["structure"], sample["sym_dataset"])
            forced_zero_333 = infer_forced_zero_mask_symmetric_rank3_jk(
                cart_rots,
                rcond=1e-10,
                zero_tol=1e-8,
            )
            forced_zero_t = torch.tensor(forced_zero_mask_rank3_to_voigt36(forced_zero_333), dtype=torch.bool)
            outputs = outputs.clone()
            outputs[forced_zero_t] = 0.0

        labels = labels.cpu()
        mae_list.append(abs(outputs - labels).view(-1).mean())
        frob_ = ((labels.view(-1) - outputs.view(-1)) ** 2).sum() ** 0.5
        frob_norm = (labels.view(-1) ** 2).sum() ** 0.5
        frob_list.append(frob_)
        percen_list.append(frob_ / (frob_norm + 1e-5))

        sys_labels, sys_outputs, sys_ideals = buckets[_crystal_system_of(sample['sym_dataset'].number)]
        sys_labels.append(labels.view(3, 6))
        sys_outputs.append(outputs)
        sys_ideals.append(torch.tensor(sample['ideal_matrix']))

    print("MAE ", np.mean(mae_list))
    print("M_Frob", np.mean(frob_list))
    percen_list = np.array(percen_list)
    for pct_label, threshold in (("25", 0.25), ("10", 0.1), ("5", 0.05), ("2", 0.02)):
        print(f"EwT {pct_label}", np.sum(percen_list < threshold) / percen_list.shape[0])

    for key, header_name, summary_name, relation, check_opposite in _CRYSTAL_SYSTEM_REPORTS:
        sys_labels, sys_outputs, sys_ideals = buckets[key]
        _report_crystal_system(
            header_name, summary_name, relation,
            sys_labels, sys_outputs, sys_ideals, check_opposite,
        )

    return



def str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"``) as True, so boolean options need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "on", "1"):
        return True
    if text in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main():
    """Parse CLI arguments, build the model, load the checkpoint and evaluate it."""
    parser = argparse.ArgumentParser(description='Training script')

    # Define command-line arguments
    # training parameters
    parser.add_argument('--epochs', type=int, default=200, help='Number of training epochs')
    parser.add_argument('--batch_size', type=int, default=64, help='batch size of training and evaluating')
    parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-05, help='weight decay')
    parser.add_argument('--loss', type=str, default='huber', help='mse or l1 or huber')
    parser.add_argument('--model', type=str, default='ceitnet', help='model architecture (only ceitnet is supported)')
    parser.add_argument('--project', type=str, default='test', help='name of project for wandb visualization')
    parser.add_argument('--name', type=str, default='test', help='name of project for storage')
    parser.add_argument('--reduce_cell', type=str2bool, default=False, help='reduce the cell into irreducible atom sets')
    # dataset parameters
    parser.add_argument('--split_seed', type=int, default=32, help='the random seed of spliting data')
    parser.add_argument('--train_ratio', type=float, default=0.8, help='training ratio used in data split')
    parser.add_argument('--val_ratio', type=float, default=0.1, help='evaluate ratio used in data split')
    parser.add_argument('--test_ratio', type=float, default=0.1, help='test ratio used in data split')
    parser.add_argument('--target', type=str, default='piezoelectric', help='dielectric, piezoelectric, or elastic')
    parser.add_argument('--threshold', type=float, default=100., help='threshold to remove samples')
    parser.add_argument('--use_corrected_structure', type=str2bool, default=True, help='correct input structure or not')
    parser.add_argument('--load_preprocessed', type=str2bool, default=True, help='load previous processed dataset')
    # checkpoint parameters
    parser.add_argument('--ckpt_path', type=str, default='repro_data/pretrained_ckpts/piezo.pt', help='Path to a checkpoint .pt file')
    parser.add_argument('--zero_mask', action='store_true', help='apply group-average-derived strict symmetry zero mask to predictions during eval.')
    args = parser.parse_args()

    print('Training settings:')
    print(f'  Epochs: {args.epochs}')
    print(f'  Learning rate: {args.learning_rate}')
    print(args)
    torch.manual_seed(args.split_seed)
    torch.cuda.manual_seed_all(args.split_seed)
    # load the model
    if args.model == "ceitnet":
        model = CEITNet(args)
    else:
        raise ValueError(f"Model {args.model} not supported")

    # Load tensors onto the CPU first so a checkpoint saved on a GPU can be
    # restored on any machine; the evaluation moves the model to `device` later.
    state_dict = torch.load(args.ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)

    test(model, args)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
