###########################################################################################
# Script for evaluating configurations contained in an xyz file with a trained model
# Authors: Ilyes Batatia, Gregor Simm
# This program is distributed under the MIT License (see MIT.md)
###########################################################################################

import argparse
import ase.data
import ase.io
import numpy as np
import torch
from mace import data
from mace.cli.convert_e3nn_cueq import run as run_e3nn_to_cueq
from mace.modules.utils import extract_invariant
from mace.tools import torch_geometric, torch_tools, utils
import torch.nn as nn
from torch.nn.functional import l1_loss
from tqdm import tqdm


def eval_model(model, atoms_list, batch_size, device):
    """Compute the energy and force mean absolute errors of ``model`` on ``atoms_list``.

    Reference values are taken from each ``ase.Atoms`` object via
    ``get_potential_energy()`` and ``get_forces()``, so every structure must
    carry them. The structures are converted to MACE ``AtomicData``, batched,
    and compared against the model predictions.

    Args:
        model: Trained MACE model. It must expose ``atomic_numbers`` and
            ``r_max``, and may expose ``heads`` (multi-head models).
        atoms_list: Sequence of ``ase.Atoms`` with reference energies and forces.
        batch_size: Number of structures per batch.
        device: Device the batches are moved to (e.g. ``"cuda"`` or ``"cpu"``).

    Returns:
        Tuple ``(energy_mae, forces_mae)`` of Python floats.
        ``energy_mae`` is the mean absolute energy error over all reference
        energies (one per configuration). ``forces_mae`` is the mean absolute
        error over every force component of every atom.

    Raises:
        ValueError: If ``atoms_list`` is empty.
    """
    # Materialize once: the structures are iterated twice below.
    atoms_list = list(atoms_list)
    if not atoms_list:
        raise ValueError("eval_model() needs at least one configuration")

    model.eval()
    z_table = utils.AtomicNumberTable([int(z) for z in model.atomic_numbers])
    # Single-head models have no ``heads`` attribute.
    heads = getattr(model, "heads", None)

    dtype = torch.get_default_dtype()
    configs = []
    for atoms in atoms_list:
        config = data.config_from_atoms(atoms)
        config.properties["energy"] = torch.tensor(atoms.get_potential_energy(), dtype=dtype)
        config.properties["forces"] = torch.tensor(atoms.get_forces(), dtype=dtype)
        configs.append(config)

    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=[
            data.AtomicData.from_config(
                config, z_table=z_table, cutoff=float(model.r_max), heads=heads
            )
            for config in configs
        ],
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Accumulate summed absolute errors and element counts rather than
    # averaging per-batch means. A mean of per-batch means over-weights a
    # short final batch, or any batch with fewer atoms.
    energy_abs_err = 0.0
    forces_abs_err = 0.0
    n_energies = 0
    n_force_components = 0

    # Gradients stay enabled (no torch.no_grad()). MACE obtains forces by
    # differentiating the energy with autograd, so disabling gradients would
    # break the force output.
    for batch in tqdm(data_loader):
        batch = batch.to(device)
        output = model(batch.to_dict(), compute_stress=False)

        energy_abs_err += l1_loss(output["energy"], batch.energy, reduction="sum").item()
        forces_abs_err += l1_loss(output["forces"], batch.forces, reduction="sum").item()
        n_energies += batch.energy.numel()
        n_force_components += batch.forces.numel()

        # .item() has already turned the results into plain floats. Dropping
        # the output releases its autograd graph before the next batch.
        del output

    return energy_abs_err / n_energies, forces_abs_err / n_force_components
    

def main():
    """Evaluate a trained MACE model on a set of xyz files and report MAEs.

    For every dataset, the energy and force MAE returned by ``eval_model`` are
    printed to stdout. They are also written to
    ``evaluation/results_<exp_name>.txt``, which is created (along with its
    directory) if needed and overwritten otherwise.
    """
    from pathlib import Path

    default_config_paths = [
        "datasets/3BPA/test_300K.xyz",
        "datasets/3BPA/test_600K.xyz",
        "datasets/3BPA/test_1200K.xyz",
        "datasets/3BPA/test_dih_beta120.xyz",
        "datasets/3BPA/test_dih_beta150.xyz",
        "datasets/3BPA/test_dih_beta180.xyz",
    ]

    parser = argparse.ArgumentParser(description="Evaluate configurations with a trained MACE model")
    parser.add_argument('--exp_name', type=str, default='experiment',
                        help='Experiment name for output file')
    parser.add_argument('--batch_size', type=int, default=2,
                        help='Batch size for data loading')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to run the model on (cuda/cpu)')
    parser.add_argument('--config_paths', type=str, nargs='+', default=default_config_paths,
                        help='xyz files to evaluate (default: the 3BPA test sets)')
    parser.add_argument('--model_path', type=str, default=None,
                        help='Model file to evaluate '
                             '(default: ./checkpoints/<exp_name>_run-123_stagetwo.model)')

    args = parser.parse_args()

    torch_tools.set_default_dtype('float64')

    if args.model_path is not None:
        model_path = args.model_path
    else:
        model_path = f'./checkpoints/{args.exp_name}_run-123_stagetwo.model'

    # The .model file is a whole pickled module, not a state_dict. torch>=2.6
    # defaults to weights_only=True, which rejects it, so opt out explicitly.
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    model = torch.load(model_path, map_location=args.device, weights_only=False)
    model = model.to(args.device)
    # Evaluation never needs parameter gradients. Freezing the parameters keeps
    # the autograd graph built for the forces from tracking them.
    for param in model.parameters():
        param.requires_grad_(False)

    results_path = Path('evaluation') / f'results_{args.exp_name}.txt'
    # Create the output directory so a fresh checkout doesn't hit FileNotFoundError.
    results_path.parent.mkdir(parents=True, exist_ok=True)

    with results_path.open('w', encoding='utf-8') as f:
        for config_path in args.config_paths:
            atoms_list = ase.io.read(config_path, index=":")
            energy_loss, forces_loss = eval_model(model, atoms_list, args.batch_size, args.device)
            report = (
                f"Model: {model_path}, Config: {config_path}\n"
                f"energy: {energy_loss:.6f}\n"
                f"force: {forces_loss:.6f}"
            )
            print(report, file=f)
            print('', file=f)  # blank separator between datasets in the file
            # Flush so finished datasets survive a crash on a later one.
            f.flush()
            print(report)


if __name__ == "__main__":
    main()