import math
import random
import ase.io

def _append_frames(path, frames):
    '''Append ``frames`` (a list of ase.Atoms) to the extxyz file ``path``.

    All frames go through a single ``ase.io.write`` call, so the output file is
    opened once per batch instead of once per frame. ``append=True`` keeps the
    append-only behaviour: re-running the script adds duplicate frames to an
    existing output file, so delete old outputs before regenerating.
    '''
    if frames:  # nothing to write -> leave the file untouched (not even created)
        ase.io.write(path, frames, format='extxyz', append=True)


def _percent(count, total):
    '''Return ``count / total`` as a percentage rounded to 1 decimal; 0.0 if ``total`` is 0.'''
    return round(count / total * 100, 1) if total else 0.0


def _split_indices(inds_sampled, train_ratio, rng):
    '''Randomly partition ``inds_sampled`` into (train, valid, test) index lists.

    The train list gets ``floor(n * train_ratio)`` indices, the validation list
    ``floor(remaining / 2)`` and the test list whatever is left. Given distinct
    input indices, the three lists are disjoint and together cover the input.

    NOTE: the leftover pools are built with set differences, so their order
    (and therefore which indices ``rng.sample`` draws for validation, and the
    order of the validation/test lists) follows CPython's set iteration order.
    That order is an implementation detail. It is deliberately kept unchanged
    so that an existing split is reproduced exactly on the same interpreter.
    '''
    n_steps = len(inds_sampled)
    n_train = int(math.floor(n_steps * train_ratio))
    n_valid = int(math.floor((n_steps - n_train) / 2))

    inds_train = rng.sample(inds_sampled, n_train)
    remaining_inds = list(set(inds_sampled) - set(inds_train))
    inds_valid = rng.sample(remaining_inds, n_valid)
    inds_test = list(set(remaining_inds) - set(inds_valid))
    return inds_train, inds_valid, inds_test


def _create_ood_dataset(stride):
    '''Sample every ``stride``-th frame of the melt/quench/relax runs into OOD.xyz.

    Per-trajectory and total sampling statistics are appended to OOD_info.txt.
    '''
    ood = ["OOD_MeltQuenchRelax/Melt.xyz", "OOD_MeltQuenchRelax/Quench.xyz", "OOD_MeltQuenchRelax/Relax.xyz"]
    n_tot_sampled = 0
    n_tot = 0
    with open('OOD_info.txt', 'a') as info:
        for vpr in ood:
            aimd = ase.io.read("./{}".format(vpr), format='extxyz', index=':')
            n_steps = len(aimd)
            n_tot += n_steps

            # frames 0, stride, 2*stride, ...
            idx_sampled = list(range(0, n_steps, stride))
            _append_frames("OOD.xyz", [aimd[i] for i in idx_sampled])
            n_tot_sampled += len(idx_sampled)

            info.write(f'For {vpr} with total of {n_steps} steps, numb of samples = {len(idx_sampled)}, indices sampled: {idx_sampled} \n')

        info.write(f'Total: {n_tot_sampled} steps are sampled from {n_tot} samples of {ood}\n')


def _create_id_dataset(stride, train_ratio, seed):
    '''Build Trainset.xyz, Validset.xyz and Testset.xyz from the ID trajectories.

    Each trajectory is sub-sampled with ``stride`` and split on its own (see
    ``_split_indices``). The split summary for every file, plus grand totals,
    is appended to Split_Strategy_info.txt.
    '''
    vpr_list = ([f"./SiN_compound/{i}.xyz" for i in range(1, 93)]
                + [f"./Si_Only/{i}.xyz" for i in range(93, 107)]
                + [f"./N_Only/{i}.xyz" for i in range(107, 111)])

    # A private generator yields the same draws as random.seed(seed) followed by
    # random.sample(...), without mutating the module-global random state.
    rng = random.Random(seed)
    tot_train, tot_valid, tot_test = 0, 0, 0
    with open('Split_Strategy_info.txt', 'a') as info:
        for vpr in vpr_list:
            aimd = ase.io.read(vpr, format='extxyz', index=':')
            n_tot_steps = len(aimd)
            inds_sampled = list(range(0, n_tot_steps, stride))
            n_steps = len(inds_sampled)

            inds_train, inds_valid, inds_test = _split_indices(inds_sampled, train_ratio, rng)
            n_train, n_valid, n_test = len(inds_train), len(inds_valid), len(inds_test)

            _append_frames("Trainset.xyz", [aimd[i] for i in inds_train])
            _append_frames("Validset.xyz", [aimd[i] for i in inds_valid])
            _append_frames("Testset.xyz", [aimd[i] for i in inds_test])

            info.write(
                f'{vpr}: {n_steps} sampled from total {n_tot_steps} steps with an interval of {stride} steps\n'
                f'Train : Valid : Test = {n_train} : {n_valid} : {n_test} = '
                f'{_percent(n_train, n_steps)} : {_percent(n_valid, n_steps)} : {_percent(n_test, n_steps)} \n'
                f' *** Trainset : {n_train} steps, indices = {inds_train}\n'
                f' *** Validset : {n_valid} steps,  indices = {inds_valid}\n'
                f' *** Testset  : {n_test} steps,  indices = {inds_test}\n\n'
            )

            tot_train += n_train
            tot_valid += n_valid
            tot_test += n_test

        info.write(
            f'Numb of Total Trainset: {tot_train} steps\n'
            f'Numb of Total Validset:  {tot_valid} steps\n'
            f'Numb of Total Testset :  {tot_test} steps\n'
        )


def main(stride=3, train_ratio=0.8, seed=0):
    '''Create the OOD dataset and the ID train/valid/test datasets.

    Python 3.8.10 used.

    Outputs go to the current working directory:
      * OOD.xyz and OOD_info.txt
      * Trainset.xyz, Validset.xyz, Testset.xyz and Split_Strategy_info.txt
    All outputs are opened in append mode, so remove previous outputs
    before re-running.

    Args:
        stride: keep every ``stride``-th frame (starting at frame 0).
        train_ratio: fraction of the sampled frames, per trajectory, that goes
            to the training set. The rest is split evenly between valid and test.
        seed: seed for the train/valid/test random split.

    Raises:
        ValueError: if ``stride`` < 1 or ``train_ratio`` is outside [0, 1].
    '''
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")

    _create_ood_dataset(stride)
    _create_id_dataset(stride, train_ratio, seed)

if __name__ == "__main__":
    # Only build the datasets when executed as a script, never on import.
    main()
