#!/usr/bin/python
import argparse
import math
import random
import ase.io
import pathlib
import csv  

def generate_idx_dict(idx_file:str):
    """Parse an index file made of alternating name / index lines.

    Lines are consumed in pairs. The first line of each pair is a path-like
    name; the key is its last ``/``-separated component, cut at the first
    ``.``. The second line is a comma-separated list of integers. A trailing
    name line without a following index line is ignored, and a repeated key
    keeps the values of its last occurrence.

    Args:
        idx_file: Path of the index file to read.

    Returns:
        dict mapping each key to its list of ints.

    Raises:
        ValueError: If an index line contains a field that is not an integer.
    """
    with open(pathlib.Path(idx_file), 'r') as handle:
        rows = [row.strip() for row in handle]

    # Even positions hold the names, odd positions the matching index lists;
    # zip() drops an unpaired final name.
    name_rows = rows[0::2]
    index_rows = rows[1::2]
    return {
        name.split('/')[-1].split('.')[0]: [int(field) for field in indices.split(',')]
        for name, indices in zip(name_rows, index_rows)
    }

def get_split_idx(idx_size:int, split_ratio:list):
    """Randomly partition ``range(idx_size)`` into train/valid/test indices.

    Each share is ``idx_size * ratio / sum(split_ratio)`` rounded up. Because
    rounding up can overshoot ``idx_size``, the surplus is trimmed starting
    from the first share, moving on to later shares once one reaches zero, so
    no share ever becomes negative. The first two shares size the train and
    valid splits; the test split receives every remaining index.

    Args:
        idx_size: Number of indices to distribute.
        split_ratio: Relative weights, read as ``[train, valid, test]``.

    Returns:
        dict with keys ``'train'``, ``'valid'`` and ``'test'``; each value is
        a sorted list of indices, and together they cover ``range(idx_size)``
        exactly once.

    Raises:
        ValueError: If a ratio is negative or the ratios do not sum to a
            positive number.
    """
    if any(ratio < 0 for ratio in split_ratio):
        raise ValueError(f'split_ratio must not contain negative values: {split_ratio}')
    split_total = sum(split_ratio)
    if split_total <= 0:
        raise ValueError(f'split_ratio must sum to a positive value: {split_ratio}')

    idx_split = [math.ceil(idx_size*x/split_total) for x in split_ratio]

    # Trim the rounding surplus without letting a share go below zero: a
    # negative count would be read as a from-the-end slice bound below and
    # hand the train split items it was never assigned.
    surplus = sum(idx_split) - idx_size
    for pos, count in enumerate(idx_split):
        if surplus <= 0:
            break
        trimmed = min(count, surplus)
        idx_split[pos] = count - trimmed
        surplus -= trimmed

    idx_list = list(range(idx_size))
    random.shuffle(idx_list)

    train_end = idx_split[0]
    valid_end = train_end + idx_split[1]
    return {
        'train' : sorted(idx_list[:train_end]),
        'valid' : sorted(idx_list[train_end:valid_end]),
        'test' : sorted(idx_list[valid_end:]),
    }

#def initialize():
#    file_extxyz = pathlib.Path(args.output)
#    file_extxyz.unlink(missing_ok=True)

def create_extxyz_and_index(args, structures, prefix='train', suffix=''):
    """Append structures to ``<args.out>/dataset<suffix>/<prefix>.xyz``.

    The dataset directory is created if it does not exist yet. Every
    structure is written in extxyz format with ``append=True``, so frames
    already present in the target file are kept. Despite the function name,
    no index file is written here.

    Args:
        args: Object exposing an ``out`` attribute (the output base directory).
        structures: Iterable of objects providing
            ``write(filename, format=..., append=...)``.
        prefix: File name stem of the target ``.xyz`` file.
        suffix: Text appended to the ``dataset`` directory name.
    """
    pathlib.Path(f'{args.out}/dataset{suffix}').mkdir(parents=True, exist_ok=True)

    # The destination is the same for every frame, so build it once.
    destination = f'{args.out}/dataset{suffix}/{prefix}.xyz'
    for frame in structures:
        frame.write(destination, format='extxyz', append=True)
    
if __name__ == '__main__':
    #
    # Example:
    #   python HfO_Split.py -s _hfo2 -o .
    #
    # For every crystal set, '<group>/<crystal_set>.xyz' is read, each MD
    # segment is sub-sampled and randomly split, and the resulting frames are
    # appended to '<out>/dataset<suffix>/{Trainset,Validset,Testset|OOD}.xyz'.
    #
    # NOTE: the output files are written in append mode, so running the script
    # again into the same directory adds to whatever is already there.
    #

    # train : valid : test ratio used for the Crystal and Random groups.
    SPLIT_RATIO = [8,1,1]
    # The OOD group sends every snapshot to the test split.
    OOD_SPLIT_RATIO = [0,0,10]

    # Per-segment sampling tables, all indexed by md_set:
    #   SNAPSHOTS_PER_MD_STEPS - number of frames sampled from the segment
    #   D_SKIP                 - frames skipped after the segment start
    #   D_STEP                 - stride between sampled frames
    #   D_INIT                 - frame index where the segment starts
    SNAPSHOTS_PER_MD_STEPS=[230, 267, 167, 1000, 667, 667, 500]
    D_SKIP=[10,0,0,0,0,0,0]
    D_STEP=[3,3,3,3,6,6,6]
    D_INIT=[0, 700, 1500, 2000, 5000, 9000, 13000]

    parser = argparse.ArgumentParser(description='Generate xyz files')
    parser.add_argument('-s', '--suffix', type=str, default='',
                    help='Suffix of output files')
    # NOTE(review): --base is accepted but has never influenced the input
    # location; the directory is chosen per crystal set below. Confirm whether
    # it was meant to be a parent directory before wiring it in.
    parser.add_argument('-b', '--base', type=str, default='',
                    help='Base directory name for extxyz and index file')
    parser.add_argument('-o', '--out', type=str, default='.',
                    help='Base directory name for output files')
    parser.add_argument('--seed', type=int, default=None,
                    help='Seed for the random split (default: unseeded)')
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    ### Change this to sample different Crystals
    crystal_sets = list(range(1, 13))
    md_sets = list(range(7))
    ###

    pathlib.Path(f'{args.out}').mkdir(parents=True, exist_ok=True)

    for crystal_set in crystal_sets:
        # Pick the input directory, split ratio and test-file name for this
        # set. The ratio is selected per set so that the OOD override cannot
        # leak into sets processed afterwards.
        if crystal_set in (1,2,3,4,5):
            base_dir = "Crystal"
            split_ratio = SPLIT_RATIO
            prefix_test = "Testset"
        elif crystal_set in (6,7,8,9,10):
            base_dir = "Random"
            split_ratio = SPLIT_RATIO
            prefix_test = "Testset"
        elif crystal_set in (11,12):
            base_dir = "OOD"
            split_ratio = OOD_SPLIT_RATIO
            prefix_test = "OOD"
        else:
            raise ValueError(f'Unknown crystal set: {crystal_set}')

        xyz_path = f'{base_dir}/{crystal_set}.xyz'
        structures_total = ase.io.read(xyz_path, index=':', format = 'extxyz')

        for md_set in md_sets:
            n_snapshots = SNAPSHOTS_PER_MD_STEPS[md_set]
            stride = D_STEP[md_set]
            first = D_INIT[md_set] + D_SKIP[md_set]

            # Take n_snapshots frames, one every `stride`, starting at `first`.
            structures = structures_total[first : first + stride * n_snapshots : stride]
            if len(structures) != n_snapshots:
                # A short trajectory would otherwise surface as a bare
                # IndexError when the split indices are applied below.
                raise ValueError(
                    f'{xyz_path}: MD segment {md_set} yields {len(structures)} '
                    f'snapshots, expected {n_snapshots}')

            split_idx = get_split_idx(n_snapshots, split_ratio)
            train_structures = [structures[idx] for idx in split_idx['train']]
            valid_structures = [structures[idx] for idx in split_idx['valid']]
            test_structures = [structures[idx] for idx in split_idx['test']]

            create_extxyz_and_index(args, train_structures, prefix='Trainset', suffix=args.suffix)
            create_extxyz_and_index(args, valid_structures, prefix='Validset', suffix=args.suffix)
            create_extxyz_and_index(args, test_structures,  prefix=prefix_test, suffix=args.suffix)
