'''
- data_generation.py
- This file handles interfacing crucial data generation functions with the proper utilities
'''

# External imports
import os
import random
import shutil

# Internal imports
from src.core.configuration.datagen_conf import *
from src.utils.data_generation import artificial_queries as generator, conll, gmb_formatter as gmb, squad, csr
from src.utils.misc.io import print_list

'''
----------generate_artificial_data----------
- Generates a full set of artificial training data based on the schema and format provided
-----Inputs-----
- schema_name - The schema to base the data off of
- format - The data format to use (default is CoNLL-2003)
- folds - The number of folds to generate (defaults to 5)
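- batch_size - The batch size to truncate to (defaults to 0, no truncation)
- remake_data - If True, regenerate the small dataset even when it already exists (defaults to False)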
-----Output-----
- N/A - This function writes the necessary training data to files in src/data
'''
def generate_artificial_data(schema_name, format='conll', folds=FOLD_COUNT, batch_size=0, remake_data=False):
    """Generate the small artificial dataset for a schema (if needed) and format every fold.

    Args:
        schema_name: The schema to base the data off of.
        format: The data format to use ('conll', 'squad' or 'csr').
        folds: The number of folds to format (folds are numbered from 1).
        batch_size: The batch size forwarded to the formatters (0 by default).
        remake_data: When True, regenerate the small dataset even if its
            folder already exists.

    Returns:
        None. The work is delegated to the generator/formatter utilities and
        the handcrafted test file is copied next to the generated data.
    """
    # Pick the data folder that matches the requested format. Previously every
    # non-CSR format (including the default CoNLL one) was sent to 'squad' and
    # the 'ner' folder could never be selected.
    if format == 'csr':
        folder_name = 'csr'
    elif format == 'squad':
        folder_name = 'squad'
    else:
        folder_name = 'ner'

    general_path = f"{ARTIFICIAL_DATA_LOC}/{folder_name}/{schema_name}/general/"
    small_path = f"{ARTIFICIAL_DATA_LOC}/{folder_name}/{schema_name}/small/"

    if remake_data or (not os.path.exists(small_path)):
        print("GENERATION: No general dataset detected for the selected schema. Generating general data")
        # NOTE: generation of the "general" dataset (generate_templates(type="general")
        # -> fill_templates -> format_artificial_data(fold_num="general")) is
        # currently disabled; only the "small" dataset is built here.

        # Generate the small dataset query templates, fill them and write them to a file
        query_templates = generator.generate_templates(type="small")
        generator.fill_templates(query_templates, schema_name, type="small", format=folder_name)
        # Format the small dataset queries
        format_artificial_data(schema_name, fold_num="small", batch_size=batch_size, format=format)

        if format != 'csr':
            source_handcraft_path = "src/data/test_data/{}_format/{}.txt".format(format, schema_name)
            for destination_dir in (general_path, small_path):
                # The general folder is not produced in this function (see the NOTE
                # above), so make sure the destination exists before copying into it.
                os.makedirs(destination_dir, exist_ok=True)
                shutil.copy(source_handcraft_path, destination_dir + "handcrafted.txt")

    # Format all the folds using the requested format (it used to be dropped
    # here, so folds were always formatted as CoNLL).
    for fold_num in range(1, folds + 1):
        format_artificial_data(schema_name, format=format, fold_num=fold_num, batch_size=batch_size)


'''
----------format_artificial_data----------
- Formats the artificial training dataset for use in fine-tuning
-----Inputs-----
- schema_name - The schema to base the data off of
- format - The data format to use (default is CoNLL-2003)
- fold_num - The fold to format: a fold number or a dataset name such as "small" (defaults to 0)
- batch_size - The batch size to truncate to (defaults to 0, no truncation)
-----Output-----
- N/A - This function writes the necessary training data to files in src/data
'''
def format_artificial_data(schema_name, format="conll", fold_num=0, batch_size=0):
    """Dispatch formatting of one fold to the formatter matching `format`.

    Args:
        schema_name: The schema to base the data off of.
        format: One of "conll", "squad" or "csr"; any other value is a no-op.
        fold_num: The fold identifier handed to the formatter.
        batch_size: The batch size handed to the formatter.

    Returns:
        None.
    """
    # Select the format-specific formatter; unknown formats are silently ignored
    if format == "conll":
        formatter = conll.format_conll
    elif format == "squad":
        formatter = squad.format_squad
    elif format == "csr":
        formatter = csr.format_csr
    else:
        return

    formatter(schema_name, fold_num, batch_size)


'''
----------schema_data_exists----------
- Checks to see if the general artificial data exists for the given schema
-----Inputs-----
- schema_name - The schema to base the data off of
- task - The task folder under ARTIFICIAL_DATA_LOC in which to look for the schema
-----Output-----
- data_exists - True/false as to whether the data exists for the schema
'''
def schema_data_exists(schema_name, task):
    """Report whether the "general" data path exists for a schema and task.

    Args:
        schema_name: The schema whose data is being looked up.
        task: The task folder under ARTIFICIAL_DATA_LOC to look in.

    Returns:
        True if the general path exists on disk, False otherwise.
    """
    general_dir = f"{ARTIFICIAL_DATA_LOC}/{task}/{schema_name}/general"
    return os.path.exists(general_dir)


'''
----------get_data----------
- Loads a compatible dataset to a format that can be easily read by a trainer
-----Inputs-----
- dataset - The dataset to use
- format - The data format to use (default is CoNLL-2003; currently unused)
- batch_size - The batch size to truncate to (defaults to 0, no truncation; currently unused)
-----Output-----
- data - The result of gmb.import_data() when dataset is "gmb", otherwise None
'''
def get_data(dataset, format="conll", batch_size=0):
    """Load a supported dataset.

    Args:
        dataset: The dataset to use; only "gmb" is handled.
        format: Accepted for interface compatibility; not used here.
        batch_size: Accepted for interface compatibility; not used here.

    Returns:
        Whatever gmb.import_data() returns for "gmb", otherwise None.
    """
    # Guard clause: anything other than GMB is unsupported and yields None
    if dataset != "gmb":
        return None
    return gmb.import_data()
