from data_generator_val import DataGenerator
import numpy as np
import json
import os
import pandas as pd
import pickle
import time
import argparse
import traceback
import psutil


from adbench.myutils import Utils

import train
import utils.data_handling
import utils.file_handling
import utils.monitor_train

import multiprocessing as mp
import sys
from pathlib import Path

def safe_train(config, file_path_toml, is_automatric, config_folder, config_filename):
    """Run one training job with bookkeeping around it.

    The next training id is derived from the monitor CSV, a result folder is
    created for it and the TOML config is copied there. ``main`` is then run
    inside a ``try`` block: any exception is printed (with traceback) and
    flagged on the monitor instead of being propagated to the caller.

    Args:
        config: Flat config mapping. Reads ``"code_path"``, ``"result_path"``,
            ``"model_type"``, ``"backbone_model"``, ``"model_path"`` and
            ``"save"``.
        file_path_toml: Path of the TOML config that is copied into the
            result folder.
        is_automatric: When truthy the run is part of a batch experiment:
            ``experiment.json`` in ``config_folder`` is updated before and
            after training and no interactive prompt is shown.
        config_folder: Prefix that is string-concatenated with
            ``"experiment.json"``; only used when ``is_automatric`` is truthy.
        config_filename: Stored as ``"run_config"`` and later appended to
            ``"finish_config"`` in ``experiment.json``.
    """
    monitor_csv = config["code_path"] + "training_monitor.csv"
    train_id = utils.monitor_train.get_training_id(monitor_csv) + 1

    # Result folder layout: <result_path><id>_<model_type>_<backbone_model>/
    run_name = "_".join((str(train_id), config["model_type"], config["backbone_model"]))
    result_path = config["result_path"] + run_name + "/"
    utils.file_handling.create_folder(result_path)
    utils.file_handling.copy_file(file_path_toml, result_path)

    monitor = utils.monitor_train.TrainingMonitor(
        monitor_csv, result_path, train_id, config["model_type"], config["backbone_model"]
    )
    utils.monitor_train.check_disk_usage(config["model_path"])

    if is_automatric:
        # Batch mode: register this run in the experiment file, no prompt.
        experiment_file = config_folder + "experiment.json"
        experiment = utils.file_handling.read_json(experiment_file)
        experiment["train_id_list"].append(int(train_id))
        experiment["run_config"] = config_filename
        print(experiment)
        utils.file_handling.write_json(experiment_file, experiment)
        info_text = "run all"
    elif config["save"]:
        info_text = input("What is spezial about this run? ")
    else:
        info_text = "Run not saved - just debugging"

    monitor.add_info_text(info_text)
    if config["save"]:
        monitor.save_train_start()

    try:
        main(monitor, config)
    except Exception as err:
        # Top-level boundary: report the failure but still run the
        # finalisation below so the bookkeeping stays consistent.
        utils.file_handling.delete_empty_folder(result_path)
        print("Terminated unexpected")
        monitor.set_error()
        print(err)
        print(traceback.format_exc())
    finally:
        if config["save"]:
            monitor.save_train_end()
        if is_automatric:
            # Mark this config as processed (also after a failed run).
            experiment_file = config_folder + "experiment.json"
            experiment = utils.file_handling.read_json(experiment_file)
            experiment["finish_config"].append(config_filename)
            utils.file_handling.write_json(experiment_file, experiment)


def main(monitor, config):
    """Train on every dataset for every seed in ``[start_seed, end_seed)``.

    For each seed a dataset list is built via
    ``utils.data_handling.create_dataset_list``. For every dataset whose
    pickled data file ``<data_path>/<dataset>/seed_<seed>.pkl`` exists,
    ``train.train_loop`` is executed in a freshly spawned subprocess and the
    elapsed wall-clock time is recorded on the monitor. Every dataset is
    registered on the monitor, whether or not its data file was found.

    Args:
        monitor: Training monitor; ``id``, ``add_train_time`` and
            ``add_dataset`` are used here, and the object is passed on to
            ``train.train_loop``.
        config: Flat config mapping. Reads ``"start_seed"``, ``"end_seed"``,
            ``"dataset_list_AD"``, ``"dataset_list"``, ``"num_epochs"``,
            ``"model_type"``, ``"backbone_model"`` and ``"data_path"``.
    """
    start_seed = config["start_seed"]
    end_seed = config["end_seed"]

    for seed in range(start_seed, end_seed):
        datagenerator = DataGenerator(seed=seed, test_size=0.5, normal=True)
        dataset_list = utils.data_handling.create_dataset_list(
            datagenerator, config["dataset_list_AD"], config["dataset_list"]
        )
        print(sorted(dataset_list))

        for dataset in dataset_list:
            utils.monitor_train.print_console_output(
                monitor.id, seed, dataset, config["num_epochs"],
                config["model_type"], config["backbone_model"]
            )
            data_file_path = config["data_path"] + "/" + dataset + "/seed_" + str(seed) + ".pkl"

            if os.path.exists(data_file_path):
                # NOTE: pickle.load can execute arbitrary code contained in
                # the file; only point data_path at trusted dataset files.
                with open(data_file_path, "rb") as data_file:
                    data = pickle.load(data_file)
                start_time = time.time()

                print(f"Starte Training für {dataset}")
                # Each training runs in its own spawned process; the
                # arguments are pickled and sent to the child.
                proc = mp.get_context("spawn").Process(
                    target=train.train_loop,
                    args=(monitor, config, data['X_train'], dataset, False, 0, 0, seed)
                )
                proc.start()
                proc.join()  # wait until the training process has finished
                if proc.exitcode != 0:
                    # A non-zero exit code means the child raised or was
                    # killed; report it instead of silently moving on.
                    print(f"Training process for {dataset} exited with code {proc.exitcode}")
                print(f"Fertig mit {dataset}\n")

                monitor.add_train_time(dataset, start_time, time.time())

            # Registered even when no data file exists for this seed.
            monitor.add_dataset(dataset)


def _load_flat_config(file_path_toml):
    """Read a TOML config and flatten the sections selected by its model.

    The sections used are ``"General"``, the section named by
    ``General.model_type`` and the one named by that section's
    ``backbone_model`` entry.
    """
    raw_config = utils.file_handling.read_toml(file_path_toml)
    model_type = raw_config["General"]["model_type"]
    backbone_model = raw_config[model_type]["backbone_model"]
    return utils.data_handling.flatten_dict(
        raw_config, ["General", model_type, backbone_model]
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Settings')
    parser.add_argument('--config', type=int,
        default=2, help='number config file')
    parser.add_argument('--all', type=int,
        default=0, help='set 1 to train all in folder')
    parser.add_argument('--folder', type=str,
        default="./10_3_train_config/", help='folder path of config files for experiment')
    parser.add_argument('--restart', type=int,
        default=0, help='set 1 to restart experiment, from where it stopped')
    args = parser.parse_args()

    if args.all == 1:
        # The folder is string-concatenated with file names (here and in
        # safe_train), so make sure it ends with a path separator.
        folder = os.path.join(args.folder, "")

        # os.listdir returns entries in arbitrary order; sort so the configs
        # are always processed in the same sequence.
        config_files = sorted(
            name for name in os.listdir(folder) if name.endswith('.toml')
        )

        # NOTE(review): --restart 1 only keeps the existing experiment.json;
        # configs already listed in "finish_config" are still run again.
        if args.restart == 0:
            info_text = input("What is spezial about this experiment? ")
            experiment = {
                "info": info_text,
                "train_id_list": [],
                "finish_config": [],
                "run_config": None,
            }
            utils.file_handling.create_json(folder + "experiment.json", experiment)

        for config_file in config_files:
            print(f"Processing {config_file}")
            file_path_toml = folder + config_file
            print(file_path_toml)
            flatten_dict = _load_flat_config(file_path_toml)
            print(flatten_dict)
            safe_train(flatten_dict, file_path_toml, True, folder, config_file)
    else:
        # --config 0 selects "config.toml", any other N selects "config_N.toml".
        if args.config == 0:
            filename = "config.toml"
        else:
            filename = "config_" + str(args.config) + ".toml"

        file_path_toml = "./" + filename
        flatten_dict = _load_flat_config(file_path_toml)
        print(flatten_dict)

        safe_train(flatten_dict, file_path_toml, False, None, None)