from transformers import T5ForConditionalGeneration, T5TokenizerFast as T5Tokenizer
import json
import os, shutil, random
from sklearn.utils import shuffle

# internal imports
from src.core.interface import data_generation as datagen, fine_tuning as tuning
from src.utils.data_generation.pretrained.model.trainer import *
from src.utils.data_generation.pretrained.data import datasets as ds

# Path of the JSON configuration whose "schema" entry lists the schemas to train on.
CONF_LOC = "src/core/evaluation/global_config.json"
# train_df1 = make_dataset('common_gen', split='train')
# # train_df2 = make_dataset("ag_news", split="train")
# # train_df3 = make_dataset("cc_news", split="train")
# # train_df = pd.concat([train_df1, train_df2, train_df3])

# Read the configuration and release the file handle right away; the training
# loop below can run for a long time and does not need the file to stay open.
with open(CONF_LOC) as config_file:
    configuration = json.load(config_file)
SCHEMA_NAMES = configuration["schema"]

# Percentage (0-100) of every shuffled split that is kept for this run.
train_data_percent = 1
# Forwarded unchanged to datagen.generate_artificial_data(remake_data=...).
remake_data = False

for SCHEMA_NAME in SCHEMA_NAMES:
    # EVAL_DATASET_PATH = [f"src/data/test_data/squad_format/flight_delay.json"]
    print(f"training model for {SCHEMA_NAME}\n\n")

    # Point the dataset module at the jsonl files for this schema.
    # NOTE(review): presumably ds.make_dataset(from_local=True) reads these
    # module attributes -- confirm against the datasets module.
    if SCHEMA_NAME == 'combined':
        # 'combined' reuses the already generated files of the three
        # individual schemas; no data generation is triggered here.
        ds.TRAIN_DATASET_PATH = [
            "src/data/fine_tuning/csr/student_perf/small/train.jsonl",
            "src/data/fine_tuning/csr/flight_delay/small/train.jsonl",
            "src/data/fine_tuning/csr/online_delivary/small/train.jsonl",
        ]
        ds.EVAL_DATASET_PATH = [
            "src/data/fine_tuning/csr/student_perf/small/test.jsonl",
            "src/data/fine_tuning/csr/flight_delay/small/test.jsonl",
            "src/data/fine_tuning/csr/online_delivary/small/test.jsonl",
        ]
    else:
        datagen.generate_artificial_data(SCHEMA_NAME, format='csr', folds=0, remake_data=remake_data)
        ds.TRAIN_DATASET_PATH = [f"src/data/fine_tuning/csr/{SCHEMA_NAME}/small/train.jsonl"]
        ds.EVAL_DATASET_PATH = [f"src/data/fine_tuning/csr/{SCHEMA_NAME}/small/test.jsonl"]

    # One cache directory per (percentage, schema) pair.
    cache_dir = f'src/data/fine_tuning/csr/cache_{train_data_percent}/{SCHEMA_NAME}'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(cache_dir, exist_ok=True)

    train_df = ds.make_dataset('common_gen', split='train', cache_dir=cache_dir, from_local=True)
    evaluation_df = ds.make_dataset('common_gen', split='validation', cache_dir=cache_dir, from_local=True)
    # test_df = ds.make_dataset('common_gen', split='test', from_local=True)
    # test_df = test_df[:100]

    # Keep only the configured percentage of each shuffled split. The row
    # count is floored at 1 so that splits with fewer than 100 rows do not
    # collapse to an empty frame when the percentage is small.
    train_df = shuffle(train_df)
    evaluation_df = shuffle(evaluation_df)
    train_df = train_df[:max(1, int(len(train_df) * (train_data_percent/100)))]
    eval_df = evaluation_df[:max(1, int(len(evaluation_df) * (train_data_percent/100)))]
    # NOTE(review): the test subset is drawn from the same validation split
    # after a second shuffle, so it may overlap with eval_df.
    evaluation_df = shuffle(evaluation_df)
    test_df = evaluation_df[:max(1, int(len(evaluation_df) * (train_data_percent/100)))]

    # Load a fresh base checkpoint for every schema so that runs do not
    # build on the weights fine-tuned for the previous schema.
    model_name = "models/language_model/t5-base"
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    model = T5ForConditionalGeneration.from_pretrained(model_name, return_dict=True)

    module = trainer(
        tokenizer=tokenizer,
        model=model,
        train_df=train_df,
        test_df=test_df,
        val_df=eval_df,
        source_max_token_len=512,
        target_max_token_len=512,
        batch_size=4,
        max_epochs=2,
        outputdir="saved_models",
    )

    model = module.train()
    # print(module.evaluate(test_df=eval_df))

    # Store the fine-tuned model under a directory named after the schema
    # and the data percentage used.
    model_dirs = f'models/language_model/{SCHEMA_NAME}/t5-base-{train_data_percent}-per'
    os.makedirs(model_dirs, exist_ok=True)
    module.save_model(model_dir=model_dirs)

    # Cache cleanup is currently disabled:
    # shutil.rmtree(cache_dir)
    # for files in os.listdir(dir):
    #     path = os.path.join(dir, files)
    #     try:
    #         shutil.rmtree(path)
    #     except OSError:
    #         os.remove(path)