from ase.io import read, write
from ase.calculators.lammpslib import LAMMPSlib
from tersoff import *
from sqnm.vcsqnm_for_ase import aseOptimizer
# from lmpio import *
import numpy as np
import matplotlib.pyplot as plt
import json
from ase.units import GPa
from rings import calc_rings
import multiprocessing
from scipy.stats import gaussian_kde
import os
import argparse

# Charge value assigned to each element symbol. With SiO2 stoichiometry
# (1 Si : 2 O) these values sum to zero.
el_charge = dict(Si=2, O=-1)


def main():
    """Command-line entry point: compute elastic constants for each frame of a file.

    Reads every frame of the file given on the command line and drops atoms
    whose chemical symbol is 'X'. Computes (C11, C12, C44) for each frame with
    ``calc_c`` and writes the list of results as JSON to ``<basename>-c.json``
    in the same directory as the input file.
    """
    parser = argparse.ArgumentParser(description='Analyze SiO2 structure and calculate elastic constants')
    parser.add_argument('file', type=str, help='Path to extxyz file to analyze')
    in_path = parser.parse_args().file

    # Keep only real atoms: anything labelled 'X' is excluded from the analysis.
    frames = [
        frame[[atom.index for atom in frame if atom.symbol != 'X']]
        for frame in read(in_path, index=':')
    ]

    properties = []
    for frame in frames:
        c11, c12, c44 = calc_c(frame, delta=2.e-2, force_tol=0.05)
        properties.append({'C11 [GPa]': c11, 'C12 [GPa]': c12, 'C44 [GPa]': c44})

    # Output goes next to the input file: <dir>/<stem>-c.json
    stem = os.path.splitext(os.path.basename(in_path))[0]
    out_path = os.path.join(os.path.dirname(in_path), f'{stem}-c.json')
    write_json(properties, out_path)


def write_json(data, filename):
    """Write ``data`` to ``filename`` as indented JSON and print the path.

    The parent directory is created if it does not exist. A bare filename
    (no directory component) is written to the current working directory.

    Parameters
    ----------
    data : object
        Any JSON-serializable object.
    filename : str or os.PathLike
        Destination path.
    """
    directory = os.path.dirname(filename)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)
    print(f'Properties written to {filename}')


def calc_es(ats):
    """Return the per-atom Tersoff energy before and after a relaxation.

    A copy of ``ats`` gets a fresh ``TersoffCalc`` and is optimized with
    ``vc_relax=True``. The input structure itself is not modified.

    Returns
    -------
    tuple
        (initial energy / number of atoms, relaxed energy / number of atoms),
        in the calculator's energy units.
    """
    relaxed = ats.copy()
    relaxed.calc = TersoffCalc()
    n_atoms = len(relaxed)
    e_initial = relaxed.get_potential_energy()

    # NOTE(review): the meaning of a negative initial_step_size is defined by
    # the sqnm package; confirm against its documentation.
    aseOptimizer(
        relaxed,
        vc_relax=True,
        force_tol=1e-2,
        initial_step_size=-0.001,
        nhist_max=10,
        maximalSteps=1500,
    ).optimize()

    e_relaxed = relaxed.get_potential_energy()
    return e_initial / n_atoms, e_relaxed / n_atoms


def calc_c(ats, delta=2.e-2, force_tol=0.05, central=False):
    """Estimate averaged elastic constants C11, C12 and C44 of a structure.

    Procedure:

    1. Relax a copy of ``ats`` with ``TersoffCalc`` (``vc_relax=True``).
    2. For each of the nine elements (i, j), build a 3x3 deformation matrix
       ``M = I`` with ``M[i, j]`` perturbed by ``delta``. Apply it as
       ``cell @ M`` with the atoms scaled along.
    3. Relax the ions at fixed cell (``vc_relax=False``) and record the
       stress change. This gives ``c[:, :, i, j]`` by finite differences.
    4. Contract the 3x3x3x3 tensor to 6x6 Voigt form, symmetrized over the
       last index pair.
    5. Average the diagonal normal entries (C11), the off-diagonal normal
       entries (C12) and the diagonal shear entries (C44).

    Parameters
    ----------
    ats : ase.Atoms
        Input structure. It is not modified; a copy is relaxed.
    delta : float
        Size of the perturbation added to the deformation matrix element.
    force_tol : float
        Force tolerance passed to the optimizer for every relaxation.
    central : bool, keyword, default False
        If False, use a one-sided forward difference against the relaxed
        reference stress (the original behavior; 9 strained relaxations).
        If True, use a two-point central difference with perturbations of
        +delta and -delta (18 strained relaxations, second-order accurate
        in ``delta``).

    Returns
    -------
    tuple
        (C11, C12, C44) in GPa.
    """
    # Optimizer settings shared by every relaxation below. Only ``vc_relax``
    # differs between the reference relaxation and the strained ones.
    opt_kwargs = dict(
        force_tol=force_tol,
        initial_step_size=-0.001,
        nhist_max=10,
        maximalSteps=1500,
    )

    # Reference state: fully relaxed private copy of the input.
    tats = ats.copy()
    tats.calc = TersoffCalc()
    aseOptimizer(tats, vc_relax=True, **opt_kwargs).optimize()
    stress_0 = tats.get_stress(voigt=False)

    def strained_stress(i, j, step):
        """Return the stress after adding ``step`` to deformation element
        (i, j) of the reference cell and relaxing the ions at fixed cell."""
        s = tats.copy()
        s.calc = tats.calc  # reuse the reference calculator instance
        deform = np.eye(3)
        deform[i, j] += step
        s.set_cell(s.get_cell(complete=True) @ deform, scale_atoms=True)
        aseOptimizer(s, vc_relax=False, **opt_kwargs).optimize()
        return s.get_stress(voigt=False)

    # c[:, :, i, j] = d(stress) / d(deform[i, j]), by finite differences.
    c = np.zeros((3, 3, 3, 3))
    for i in range(3):
        for j in range(3):
            stress_plus = strained_stress(i, j, +delta)
            if central:
                stress_minus = strained_stress(i, j, -delta)
                c[:, :, i, j] = (stress_plus - stress_minus) / (2 * delta)
            else:
                c[:, :, i, j] = (stress_plus - stress_0) / delta

    # Voigt index -> Cartesian index pair: xx, yy, zz, yz, zx, xy.
    voigt_pairs = [(0, 0), (1, 1), (2, 2), (1, 2), (2, 0), (0, 1)]
    cc = np.zeros((6, 6))
    for a, (i, j) in enumerate(voigt_pairs):
        for b, (k, l) in enumerate(voigt_pairs):
            # The applied deformation matrix is not symmetric, so
            # c[i, j, k, l] and c[i, j, l, k] need not agree exactly.
            # Average the two orderings. (c[i, j, ...] == c[j, i, ...]
            # holds already because the stress tensor is symmetric.)
            cc[a, b] = (c[i, j, k, l] + c[i, j, l, k]) / 2

    c11s = [cc[0, 0], cc[1, 1], cc[2, 2]]
    c12s = [cc[0, 1], cc[1, 2], cc[2, 0], cc[1, 0], cc[2, 1], cc[0, 2]]
    c44s = [cc[3, 3], cc[4, 4], cc[5, 5]]

    return np.mean(c11s) / GPa, np.mean(c12s) / GPa, np.mean(c44s) / GPa

if __name__ == '__main__':
    # Only run the command-line entry point when executed as a script,
    # so the helper functions can be imported without side effects.
    main()
