from ase.io import read, write
from ase.calculators.lammpslib import LAMMPSlib
from tersoff import *
from sqnm.vcsqnm_for_ase import aseOptimizer
import numpy as np
import matplotlib.pyplot as plt
import json
from ase.units import GPa
from rings import calc_rings
import multiprocessing
import os
import argparse

# Relative formal charge per element. With Si=+2 and O=-1 (half of the usual
# +4/-2), a stoichiometric SiO2 unit sums to zero.
el_charge = dict(Si=2, O=-1)

def calc_avg_ring_sizes(ats, target_ars=None):
    """Return the mean ring size of each structure in ``ats``.

    Rings are found with ``calc_rings`` on a (2, 2, 2) repetition of each
    structure, considering only Si-O bonds. A ring's size is the number of Si
    atoms it contains. Because only Si-O bonds are used, every ring has an
    even number of atoms and its Si count is n_atoms / 2.

    NOTE: the ``bonds`` restriction was added after the original results were
    produced; those results did not use it.

    Args:
        ats: Iterable of ASE ``Atoms`` objects. It is consumed in a single
            pass, so generators are fine.
        target_ars: Unused; kept only so existing callers keep working.

    Returns:
        list[float]: One mean ring size per structure, in input order.
        A structure for which ``calc_rings`` finds no rings gets ``nan``.
    """
    averages = []
    for at in ats:
        rings = calc_rings(at, repeat=(2, 2, 2), bonds=[('Si', 'O')])
        # Build the symbol list once per structure instead of once per ring
        # atom. Like the original lookup, this assumes calc_rings returns
        # indices into `at` itself.
        symbols = at.get_chemical_symbols()
        sizes = [sum(1 for i in ring if symbols[i] == 'Si') for ring in rings]
        # np.mean([]) also yields nan, but it emits a RuntimeWarning.
        averages.append(float(np.mean(sizes)) if sizes else float('nan'))
    return averages


def write_rsd_json(avg_ring_sizes, filename):
    """
    Write average ring sizes to a JSON file.

    The file holds a list with one object per structure, e.g.
    ``[{"avg-ring-size": 5.25}, ...]``. Missing parent directories are
    created. ``nan`` values are written as the non-standard JSON token
    ``NaN``, which is ``json.dump``'s default behaviour.

    Args:
        avg_ring_sizes (list): Average ring sizes calculated from structures.
        filename (str): Output JSON filename. It may be a bare file name
            with no directory component.
    """
    directory = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one (e.g. not for 'out.json').
    if directory:
        os.makedirs(directory, exist_ok=True)

    data = [{"avg-ring-size": float(avg_size)} for avg_size in avg_ring_sizes]

    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)

    print(f"Ring size data written to {filename}")


def main():
    """Command-line entry point for the ring-size analysis.

    Reads all frames of the given file with ASE ``read``, removes ghost atoms
    (chemical symbol ``'X'``), and optionally keeps only charge-balanced
    structures. It then computes the average ring size of each structure and
    writes the results to ``<input dir>/<input stem>-rsd.json``.
    """
    # Plot styling setting; this function does not plot anything itself.
    plt.rcParams.update({'font.size': 14})

    parser = argparse.ArgumentParser(description='Analyze SiO2 structure and calculate average ring sizes')
    parser.add_argument('file', type=str, help='Path to extxyz file to analyze')
    parser.add_argument(
        '--charge-balanced',
        action='store_true',
        help='Only analyze structures whose total charge (per el_charge) is zero',
    )
    args = parser.parse_args()

    sample_file = args.file

    # The output JSON is written next to the input file: <dir>/<stem>-rsd.json
    directory = os.path.dirname(sample_file)
    base_name = os.path.splitext(os.path.basename(sample_file))[0]
    save_file = os.path.join(directory, f'{base_name}-rsd.json')

    pred_ats = read(sample_file, index=':')

    # Remove ghost atoms (symbol 'X'). Iterating the symbol list avoids
    # creating an Atom object for every atom.
    pred_ats = [
        at[[i for i, sym in enumerate(at.get_chemical_symbols()) if sym != 'X']]
        for at in pred_ats
    ]

    # Opt-in: drop structures that are not charge balanced. Any element
    # missing from el_charge raises KeyError.
    if args.charge_balanced:
        pred_ats = [
            at for at in pred_ats
            if sum(el_charge[sym] for sym in at.get_chemical_symbols()) == 0
        ]

    avg_ring_sizes = calc_avg_ring_sizes(pred_ats)
    write_rsd_json(avg_ring_sizes, save_file)


if __name__ == "__main__":
    # Run the command-line analysis only when executed as a script, not on import.
    main()
