import json
import os
import subprocess
import sys

import numpy as np
import yaml
from megatron.data import indexed_dataset

def listdir(path, suffix=None, prefix=None):
    """Recursively collect file paths under *path*, optionally filtered by name.

    Args:
        path: Directory to walk. If it does not exist or is not a directory,
            an empty list is returned.
        suffix: If given, keep only files whose base name ends with this
            string (e.g. ``'.json'`` matches ``a.json`` but not ``a.jsonl``).
        prefix: If given, keep only files whose base name starts with this
            string.

    Returns:
        A list of file paths, each built with ``os.path.join`` starting from
        *path*. Entries are visited in sorted order so results are deterministic.
    """
    if not os.path.isdir(path):  # isdir() is False for missing paths too
        return []

    files = []
    for name in sorted(os.listdir(path)):
        full_path = os.path.join(path, name)
        if os.path.isdir(full_path):
            files.extend(listdir(full_path, suffix, prefix))
            continue
        # A real suffix check: a substring test would also accept '.jsonl'
        # when asked for '.json', breaking callers that strip the extension.
        if suffix is not None and not name.endswith(suffix):
            continue
        if prefix is not None and not name.startswith(prefix):
            continue
        files.append(full_path)
    return files


def update_yaml(file_path, eval_iters, ckptname, config_dir="./config"):
    """Write a copy of a YAML config with its top-level ``eval-iters`` overridden.

    Args:
        file_path: Source YAML file to read (left unmodified).
        eval_iters: Value stored under the ``eval-iters`` key.
        ckptname: Base name of the output file, written as
            ``<config_dir>/<ckptname>.yml``.
        config_dir: Output directory; created if it does not exist.
            Defaults to ``./config`` relative to the current working directory.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        # safe_load returns None for an empty document; start from an empty mapping.
        config = {}

    def merge_dict(dest, src):
        """Recursively merge *src* into *dest* in place and return *dest*."""
        for key, value in src.items():
            # Recurse only when both sides are mappings; otherwise overwrite.
            if isinstance(value, dict) and isinstance(dest.get(key), dict):
                merge_dict(dest[key], value)
            else:
                dest[key] = value
        return dest

    config = merge_dict(config, {"eval-iters": eval_iters})

    out_path = os.path.join(config_dir, ckptname + ".yml")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, sort_keys=False, default_flow_style=False)

def update_local_yaml(file_path, loadptname, binname, ckptname, config_dir="./config"):
    """Write a copy of a YAML config pointing at a checkpoint and a single dataset.

    The top-level keys ``load``, ``train-data-paths``, ``test-data-paths`` and
    ``valid-data-paths`` are set (the three data-path keys each to ``[binname]``);
    all other keys from *file_path* are kept in their original order.

    Args:
        file_path: Source YAML file to read (left unmodified).
        loadptname: Value stored under ``load``.
        binname: Single entry placed in each of the three data-path lists.
        ckptname: Base name of the output file, written as
            ``<config_dir>/<ckptname>.yml``.
        config_dir: Output directory; created if it does not exist.
            Defaults to ``./config`` relative to the current working directory.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    if config is None:
        # safe_load returns None for an empty document; start from an empty mapping.
        config = {}

    def merge_dict(dest, src):
        """Recursively merge *src* into *dest* in place and return *dest*."""
        for key, value in src.items():
            # Recurse only when both sides are mappings; otherwise overwrite.
            if isinstance(value, dict) and isinstance(dest.get(key), dict):
                merge_dict(dest[key], value)
            else:
                dest[key] = value
        return dest

    # Each key gets its own list: sharing one list object makes yaml.dump emit
    # anchors/aliases (&id001 / *id001) instead of three plain lists.
    config = merge_dict(config, {
        "load": loadptname,
        "train-data-paths": [binname],
        "test-data-paths": [binname],
        "valid-data-paths": [binname],
    })

    out_path = os.path.join(config_dir, ckptname + ".yml")
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        yaml.dump(config, f, sort_keys=False, default_flow_style=False)


if __name__ == '__main__':

    filelist = sys.argv[1]
    file_list = listdir(filelist, suffix='.json')
    print(file_list)

    for file in file_list:
        current_file_dir = os.path.dirname(file)
        outputdir = current_file_dir.replace("eval_results", "format_eval_results")
        outputname = file.replace("eval_results", "format_eval_results") + "l"
        print(outputname)
        outputbindir = current_file_dir.replace("eval_results", "eval_bin")
        outputbinname = file.replace("eval_results", "eval_bin")[:-5]
        print(outputbinname)
                
        cwd = os.getcwd()
        ckptname = file[13:-13]

        loadptname = cwd + "/1B/ckpt-1b/" + ckptname
        if not os.path.exists(loadptname):   
            continue
            
        if os.path.exists(outputname):
            print("file exists", outputname)
            continue
        if not os.path.exists(outputdir):
            os.makedirs(outputdir)

        fout = open(outputname, "w", encoding="utf-8")
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                output = {}
                obj = json.loads(line.strip())
                for i in range(len(obj["all_responses"])):
                    output["text"] = obj["all_responses"][i]['response']
                    if output:
                        fout.write(json.dumps(output, ensure_ascii=False) + '\n')
        # tokeners
        cmd = "python " + cwd + "/gpt-neox-tokeners/tools/datasets/preprocess_data.py --input " +  outputname + " --output-prefix " +  outputbinname  + " --dataset-impl mmap --tokenizer-type TiktokenTokenizer --append-eod --workers 60"   
        print(cmd)
        os.system(cmd)
        outputbinname = outputbinname + "_tokeners"

        dataset = indexed_dataset.make_dataset(outputbinname, "mmap")
        size = np.sum(dataset.sizes)
        print("Dataset size in tokens is", size)
        batch_size = 16384
        eval_iters = int(size / batch_size)
        print(eval_iters)

        binname = cwd + "/" + outputbinname
        print("binname:", binname)

        loadptname = cwd + "/1B/ckpt-1b/" + ckptname 
        print("loadptname: ", loadptname)
        update_local_yaml("./dataconfig.yml",loadptname, binname, ckptname+"_local")
        update_yaml("./modelconfig.yml", eval_iters, ckptname)

        ## eval
        cmd = "python " + cwd + "/gpt-neox-eval/deepy.py " + cwd + "/gpt-neox-eval/train.py " + "./config/" + ckptname + ".yml " + "./config/" + ckptname + "_local" + ".yml"  
        print(cmd)   
        os.system(cmd)
 


    
