from log import log_class
from prompts import prompter
from plans import planner
import json
from generation import generator
from search_strategy import searcher
from ZC_proxies import ZC_proxy_class
import gc
import traceback
import multiprocessing

# The Text_to_ML class is the main class that organizes the Text-to-ML process
class Text_to_ML:
    """Coordinate the three phases of the Text-to-ML process.

    The phases are plan generation, module (code) generation and configuration
    search. Each phase can be switched off through the ``plan_generation``,
    ``module_generation`` and ``search`` attributes before calling ``run``.
    Plans and the task description are handed from one phase to the next
    through the logger instance.
    """

    # Maps the "GPT version" configuration value to the LLM model identifier.
    _GPT_MODELS = {"3.5": "gpt-3.5-turbo", "4": "gpt-4", "4o": "gpt-4o"}

    # Task types (first connector entry of the first plan) that use the
    # lightning framework, synthetic data and ZC proxy filtering.
    _DEEP_LEARNING_TASKS = ("computer vision", "natural language processing")

    # Task type that uses the tabular framework.
    _TABULAR_TASK = "traditional algorithms"

    # Sentinel marking "the objective function reported nothing".
    _NO_SCORE = object()

    def __init__(self, config, debug=True):
        """Store the configuration and create the logger and prompter.

        Args:
            config: Mutable mapping with the user configuration. It is modified
                in place: the planner, search and time-limit settings below
                overwrite any values already stored under the same keys. The
                keys "GPT version" and "workspace" are read here; further keys
                are read by ``run``.
            debug: Whether to print debug information; forwarded to the
                planner and the generator.

        Raises:
            ValueError: If ``config["GPT version"]`` is not a supported version.
        """
        # load the configuration file
        self.config = config

        # initialize some configuration parameters
        self.config["planner configuration"] = {
            "maximum_models": 2,
            "maximum_data_preparations": 2,
        }
        self.config["train_val_ratio"] = 0.1666
        self.config["max_trials"] = 30
        self.config["search_strategy"] = "random_search"
        self.config["search_settings"] = {
            "skip": 0.01,
            "min_delta_step": 0,
            "patience_step": 0.5,
            "max": False,
            "min_delta": 0,
            "patience": 2,
            "max_epochs": 15,
            "check_pointing": False,
        }
        self.config["execution_time_limit"] = 3600

        # load the LLM model; fail here rather than with an AttributeError
        # on ``self.model`` somewhere inside ``run``
        self.model = self._resolve_model(self.config["GPT version"])

        # the logger is an object that manages important information for the Text-to-ML process
        self.logger_instance = log_class(self.config["workspace"])

        # the prompter is an object that assembles the prompts for each LLM inquiry
        self.prompter_instance = prompter(self.config["planner configuration"])

        # settings regarding turning on or off certain phases of the Text-to-ML process
        self.plan_generation = True
        self.module_generation = True
        self.search = True

        # whether to print debug information
        self.debug = debug

    @classmethod
    def _resolve_model(cls, version):
        """Return the LLM model identifier for a "GPT version" value.

        Raises:
            ValueError: If ``version`` is not one of the supported versions.
        """
        try:
            return cls._GPT_MODELS[version]
        except (KeyError, TypeError):
            supported = ", ".join(repr(name) for name in cls._GPT_MODELS)
            raise ValueError(
                f"Unsupported GPT version {version!r}; expected one of: {supported}"
            ) from None

    @staticmethod
    def _best_result(search_history, maximize):
        """Pick the best (configuration, score) pair from a search history.

        Only numeric scores take part; markers such as "failed" and NaN scores
        are ignored, because NaN makes ``min``/``max`` depend on item order.

        Args:
            search_history: Mapping from configuration to recorded score.
            maximize: True if a larger score is better, False if smaller.

        Returns:
            The ``(configuration, score)`` pair with the best score, or
            ``(None, None)`` when no numeric score was recorded.
        """
        scored = [
            (configuration, score)
            for configuration, score in search_history.items()
            # "score == score" is False only for NaN
            if isinstance(score, (int, float)) and score == score
        ]
        if not scored:
            return None, None
        pick = max if maximize else min
        return pick(scored, key=lambda pair: pair[1])

    def run(self):
        """Run every enabled phase in order: planning, generation, search."""
        if self.plan_generation:
            self._run_planning()
        if self.module_generation:
            self._run_module_generation()
        if self.search:
            self._run_search()

    def _build_generator(self, task_description):
        """Create a generator for the given task description."""
        return generator(task_description, self.config["key"], self.model,
                         workspace=self.config["workspace"], debug=self.debug)

    def _run_planning(self):
        """Generate the plans for the ML task and save them with the logger."""
        # loads the textual task description from the user
        task_description = [
            self.config["input data"],
            self.config["output data"],
            self.config["task objective"],
            self.config["evaluation metrics"],
            self.config["files"],
        ]
        # loads the task description with the logger instance
        self.logger_instance.task_description = task_description

        # initializes the planner instance. The planner is an object responsible for
        # generating the plans for the ML task.
        self.planner_instance = planner(task_description, "", self.config["key"], self.model,
                                        self.config["planner configuration"], self.debug)
        # runs the planner
        print("------------Planning starts!------------")
        self.planner_instance.run()
        print("------------Planning ends!------------")

        # load the generated plans with the logger instance and save it
        self.logger_instance.plans = self.planner_instance.connector_set
        self.logger_instance.save()

    def _run_module_generation(self):
        """Generate the code modules described by the saved plans."""
        # load the generated plans
        task_description = self.logger_instance.read()["task description"]
        plans = self.logger_instance.read()["plans"]

        # Configure the multiprocessing settings to handle complex errors related to CUDA
        multiprocessing.set_start_method("fork", force=True)

        # initialize the generator instance. The generator is an object responsible for
        # generating the code for the ML task.
        self.generator_instance = self._build_generator(task_description)

        train_val_ratio = self.config["train_val_ratio"]
        planner_limits = self.config["planner configuration"]

        # if the task is deep learning, generate the code for producing synthetic data
        # for testing the modeling modules
        if plans[0]["connector"][0] in self._DEEP_LEARNING_TASKS:
            print("--------Generating programs for producing synthetic data:--------")
            self.generator_instance.configure([0, "simulated_data", 0])
            # load the logs into the generator instance
            self.generator_instance.load_logs()
            # converts the loaded plans from the logs into a search space
            self.generator_instance.connector_search_space_formation()
            # refreshes the loaded logs
            self.generator_instance.load_logs()
            self.generator_instance.run()

        # generate the code for the post processing modules
        print("--------Generating post processing modules:--------")
        self.generator_instance.configure([0, "post_processing", 0], train_val_ratio)
        # load the logs into the generator instance
        self.generator_instance.load_logs()
        # NOTE(review): the original comment here announced a conversion of the plans
        # into a search space, but no such call is made for post processing (unlike
        # the synthetic-data step above) -- confirm that this is intentional.
        self.generator_instance.run()

        # if the process does not involve search, generate the code for all data
        # preparation modules in the plans
        if not self.search:
            print("--------Generating data preparation modules:--------")
            for version_choice in range(planner_limits["maximum_data_preparations"]):
                self.generator_instance.configure([0, "data_preparation", version_choice],
                                                  train_val_ratio)
                self.generator_instance.run()

        # generate the code for all modeling modules in the plans
        print("--------Generating modeling modules:--------")
        for version_choice in range(planner_limits["maximum_models"]):
            self.generator_instance.configure([0, "modeling", version_choice], train_val_ratio)
            self.generator_instance.run()

    def _run_search(self):
        """Search over configurations and print the best one found.

        Raises:
            ValueError: If the task type stored in the plans is neither the
                tabular type nor one of the deep learning types.
        """
        if not self.module_generation:
            # no generator exists yet when the generation phase was skipped
            task_description = self.logger_instance.read()["task description"]
            self.generator_instance = self._build_generator(task_description)

        # load the previously generated plans
        plans = self.logger_instance.read()["plans"]
        task_type = plans[0]["connector"][0]
        is_deep_learning = task_type in self._DEEP_LEARNING_TASKS

        # load the respective objective function depending on the identified task type
        if task_type == self._TABULAR_TASK:
            from framework_tabular import objective_function
        elif is_deep_learning:
            from framework_lightning import objective_function
            import torch
        else:
            # previously this surfaced later as an UnboundLocalError inside the loop
            raise ValueError(f"Unsupported task type for search: {task_type!r}")

        # load the information regarding previously generated modules
        self.generated_modules = {}
        for component in ["data_preparation", "modeling"]:
            try:
                self.generated_modules[component] = self.logger_instance.read()["plans"][0][
                    f"generated_{component}_modules"]
            except (KeyError, IndexError, TypeError):
                # nothing recorded for this component yet
                self.generated_modules[component] = []

        # initialize the searcher instance. The searcher is an object responsible for
        # searching the best configuration for the ML task using optimization algorithms.
        self.search_instance = searcher(self.config["workspace"], self.config["search_strategy"])
        maximum_data_preparations = self.config["planner configuration"]["maximum_data_preparations"]

        if is_deep_learning:
            # for deep learning tasks, perform ZC proxy evaluations to filter the
            # modeling modules and initialize the searcher instance.
            print("--------Performing ZC proxy evaluations:--------")
            self.ZC_proxy_instance = ZC_proxy_class(self.config["workspace"], debug=False)
            for modeling in self.generated_modules["modeling"]:
                self.ZC_proxy_instance.ZC_proxy_test(modeling)
            # rank the modeling modules based on the ZC proxy evaluations
            ZC_ranking = self.ZC_proxy_instance.rank()
            self.search_instance.configure(0, mode="CCF", filtered_modeling_modules=ZC_ranking,
                                           maximum_data_preparation_modules=maximum_data_preparations)
        else:
            # for tabular tasks, configure the searcher instance.
            self.search_instance.configure(0, mode="combination",
                                           maximum_data_preparation_modules=maximum_data_preparations)

        # perform the search iteratively
        print("------------Searching begins!------------")
        for trial in range(1, self.config["max_trials"] + 1):
            # designate the configuration to be evaluated
            configuration = self.search_instance.designate()
            print(f"--------Searching configuration in trial {trial}:--------")
            print("The configuration is:", configuration)

            # deep learning configurations are sequences whose first entry holds the
            # data preparation choice; tabular configurations are a single mapping
            if is_deep_learning:
                data_preparation = configuration[0]["data_preparation"]
            else:
                data_preparation = configuration["data_preparation"]

            # Configure the multiprocessing settings to handle complex errors related to CUDA
            multiprocessing.set_start_method("fork", force=True)

            # generate the code for the designated data preparation module if it has not
            # been generated before
            fail_flag = False
            if data_preparation not in self.generated_modules["data_preparation"]:
                print("--------Generating data preparation modules:--------")
                self.generator_instance.configure([0, "data_preparation", data_preparation],
                                                  self.config["train_val_ratio"])
                # load the logs into the generator instance
                self.generator_instance.load_logs()
                # marker: delete it?
                if is_deep_learning:
                    self.generator_instance.connector_search_space_formation(load_only=True)

                fail_flag = self.generator_instance.run()
                if not fail_flag:
                    self.generated_modules["data_preparation"].append(data_preparation)

            # Configure the multiprocessing settings to handle complex errors.
            multiprocessing.set_start_method("spawn", force=True)

            # load the configuration into tuple format for information exchange among
            # processes under multiprocessing
            if is_deep_learning:
                configuration_in_tuple = tuple(
                    value for config in configuration for value in config.values())
                objective_args = (configuration, plans, self.config["workspace"],
                                  self.config["search_settings"], False)
            else:
                configuration_in_tuple = tuple(configuration.values())
                objective_args = (configuration, self.config["workspace"],
                                  self.config["search_settings"], False)

            if fail_flag:
                # the data preparation module could not be generated; skip the trial
                continue

            try:
                # clear the memory before running the objective function
                if is_deep_learning:
                    gc.collect()
                    torch.cuda.empty_cache()
                    torch.cuda.reset_max_memory_allocated()
                    torch.cuda.reset_max_memory_cached()
                    torch.cuda.set_device(0)

                score = self._evaluate_configuration(objective_function, objective_args)

                # Save the configuration-scores pair as the search history
                self.search_instance.update_search_history_full(configuration_in_tuple, score)
            except Exception as fail:
                print("-----Search for the configuration failed, skipping...-----")
                print(f"Reason of failure:\n{fail}.\n{traceback.format_exc()}")

                # Save the failed configuration as the search history
                self.search_instance.update_search_history_full(configuration_in_tuple, "failed")

        # print the search results
        search_history = list(self.search_instance.search_history.values())[0]["full"]
        print("------------Search ends!------------")
        print("The search history is:")
        print(search_history)
        best_configuration, best_score = self._best_result(
            search_history, self.config["search_settings"]["max"])
        print("The best configuration identified is:", best_configuration)
        print("The performance of the configuration is:", best_score)

    def _evaluate_configuration(self, objective_function, objective_args):
        """Run the objective function in a child process and return its score.

        The child receives a fresh queue as its first argument followed by
        ``objective_args``. The last value found on the queue once the wait is
        over is used as the score.

        Args:
            objective_function: Callable used as the target of the child process.
            objective_args: Positional arguments passed after the queue.

        Returns:
            The last value the child put on the queue.

        Raises:
            RuntimeError: If nothing was put on the queue, e.g. because the
                child crashed or exceeded ``config["execution_time_limit"]``.
        """
        time_limit = self.config["execution_time_limit"]
        queue = multiprocessing.Queue()
        process = multiprocessing.Process(target=objective_function,
                                          args=(queue, *objective_args))

        print("-----Training and testing the configuration-----")
        process.start()
        process.join(time_limit)

        # start from the sentinel so a score left over from an earlier trial can
        # never be attributed to this configuration
        score = self._NO_SCORE
        while not queue.empty():
            score = queue.get()

        # join() with a timeout only stops waiting; stop a child that overran the
        # limit so it does not keep running during later trials. The queue is
        # drained first because terminating a writer may leave it unusable.
        if process.is_alive():
            process.terminate()
            process.join()

        if score is self._NO_SCORE:
            raise RuntimeError(
                f"the objective function reported no score within {time_limit} seconds "
                f"(exit code: {process.exitcode})")
        return score


if __name__ == "__main__":
    # Script entry point: load the JSON configuration and run the whole process.

    # The multiprocessing documentation asks for freeze_support() to be called
    # straight after the main guard, before any other multiprocessing call.
    multiprocessing.freeze_support()

    # Configure the multiprocessing settings to handle complex errors related to CUDA
    multiprocessing.set_start_method("spawn", force=True)

    # specify the path to the configuration file
    config_path = "configs_learning.json"

    # load the configuration file; JSON text is UTF-8, so do not rely on the
    # platform default encoding
    with open(config_path, "r", encoding="utf-8") as json_file:
        config = json.load(json_file)

    # initialize the Text_to_ML instance
    instance = Text_to_ML(config)

    # run the Text_to_ML instance
    instance.run()
