import argparse
import logging

import numpy as np
import torch
from LorentzMACE.data.utils import unpack_configs_from_hdf5
from LorentzMACE.tools import torch_geometric

from LorentzMACE import data, tools, modules


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the evaluation script.

    Required options are the HDF5 configuration file, the model file and the
    output prefix. The remaining options have defaults: distance cutoffs,
    device, default dtype and batch size.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--configs', help='path to h5 file', required=True)
    parser.add_argument('--model', help='path to model', required=True)
    parser.add_argument('--output', help='output path', required=True)
    parser.add_argument('--r_max_in',
                        help='inner distance cutoff (in Ang)',
                        type=float,
                        default=0.0)
    # argparse does not run `type` on non-string defaults, so the default is
    # written as a float literal to keep args.r_max_out a float either way.
    parser.add_argument('--r_max_out',
                        help='outer distance cutoff (in Ang)',
                        type=float,
                        default=100000.0)
    parser.add_argument('--device',
                        help='select device',
                        type=str,
                        choices=['cpu', 'cuda'],
                        default='cpu')
    parser.add_argument('--default_dtype',
                        help='set default dtype',
                        type=str,
                        choices=['float32', 'float64'],
                        default='float64')
    parser.add_argument('--batch_size',
                        help='batch size',
                        type=int,
                        default=64)
    return parser.parse_args()


def main():
    """Evaluate a saved model on the configurations of an HDF5 file.

    Runs the model from ``--model`` over the configurations in ``--configs``
    in batches of ``--batch_size``. ``output['energy']`` is treated as class
    scores: the prediction is its argmax over the last axis, compared against
    ``batch['signal']`` (presumably integer class labels -- confirm).

    Writes two files:
      * ``{output}_energies.npy``: the concatenated ``output['energy']`` arrays
        of all batches, in dataset order (the loader does not shuffle).
      * ``{output}_accuracies.npy``: the overall accuracy in percent (scalar).

    Raises:
        ValueError: if no configurations are read from ``--configs``.
    """
    # The root logger drops INFO records unless configured, which would hide
    # the accuracy report below. basicConfig is a no-op if logging already
    # has handlers, so it does not override an existing setup.
    logging.basicConfig(level=logging.INFO)

    args = parse_args()
    tools.set_default_dtype(args.default_dtype)
    device = tools.init_device(args.device)

    # Load data and prepare input
    configs = unpack_configs_from_hdf5(args.configs)
    dataset = [
        data.AtomicData.from_config(c,
                                    cutoff_in=args.r_max_in,
                                    cutoff_out=args.r_max_out)
        for c in configs
    ]
    if not dataset:
        # Fail with a clear message instead of an opaque np.concatenate error.
        raise ValueError(f'No configurations found in {args.configs}')
    data_loader = torch_geometric.dataloader.DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    # Load model. NOTE: torch.load unpickles the file, which can execute
    # arbitrary code -- only load model files from trusted sources.
    model = torch.load(f=args.model, map_location=device)
    # Put layers such as dropout/batch-norm (if any) into inference mode.
    # NOTE(review): the loop is deliberately not wrapped in torch.no_grad() in
    # case the model relies on autograd internally -- confirm before adding.
    model.eval()

    # Collect data
    energies_list = []
    num_correct = 0
    num_total = 0

    for batch in data_loader:
        batch = batch.to(device)
        output = model(batch, training=False)
        # softmax is monotonic, so argmax over the raw scores gives the same
        # predicted class without the extra computation.
        predictions = torch.argmax(output['energy'], dim=-1)
        correct = predictions == batch['signal']
        # Count samples rather than averaging per-batch percentages: the last
        # batch may be smaller (drop_last=False), and per-batch rounding would
        # also lose precision.
        num_correct += int(correct.sum().item())
        num_total += correct.numel()
        energies_list.append(tools.to_numpy(output['energy']))

    energies = np.concatenate(energies_list, axis=0)
    mean_accuracy = 100.0 * num_correct / num_total
    logging.info(f'Mean accuracy: {mean_accuracy:.2f}%')
    np.save(f'{args.output}_energies.npy', energies)
    np.save(f'{args.output}_accuracies.npy', mean_accuracy)


# Run the evaluation only when executed as a script; importing this module
# defines the functions without running anything.
if __name__ == '__main__':
    main()