from prompts import prompter
from base_modules.code_management import exception_handler
from base_modules.inqury import GPT
import unit_testing
from base_modules.code_management import overtime_kill
import importlib
import gc
import traceback
import time
from log import log_class


from hyperparameters import search_space
from hyperparameters import extractor

class generator():

    def __init__(self,task_description,key,model, workspace, feed_backs=True,explanation=False, debug=False,
                 patience=10, test_sampling_count=3):
        """Set up the code generator and the helper objects it drives.

        Args:
            task_description: task text; passed to the extractor here and
                converted by the prompter in load_logs().
            key: API key forwarded to GPT and to the extractor.
            model: model name forwarded to GPT and to the extractor.
            workspace: workspace path; generated modules are written inside
                it, and it is passed to the extractor, search-space and log
                helpers.
            feed_backs: attach code/error/traceback feedback from failed
                attempts to later prompts.
            explanation: after each failed attempt ask the model to explain
                the failure, and feed the explanations back into later prompts.
            debug: print intermediate results.
            patience: maximum number of generation attempts run() makes
                (previously hard-coded to 10).
            test_sampling_count: how many randomly sampled search-space
                configurations test_base() runs the tests with for computer
                vision / NLP connectors (previously hard-coded to 3).
        """
        # Plain configuration first; these assignments have no side effects.
        self.task_description = task_description
        self.key = key
        self.model = model
        self.debug = debug
        self.workspace = workspace
        self.patience = patience
        self.feedbacks = feed_backs
        self.explanation = explanation
        self.test_sampling_count = test_sampling_count

        # Collaborators, constructed in the same order as before.
        self.gpt_instance = GPT(self.key, model=self.model)
        self.extractor_instance = extractor(self.task_description, self.key, self.model, workspace=workspace, debug=debug)
        self.search_space_instance = search_space(workspace)
        self.information_instance = log_class(workspace)

    def configure(self, choice,train_val_ratio=None):
        """Select the generation target and reset the per-target feedback histories.

        Args:
            choice: indexable whose first three items are, in order, the
                connector identifier (used to index ``self.plans``), the
                component name and the version number.
            train_val_ratio: stored for load_logs() (unit tester) and for the
                data-preparation prompts.
        """
        connector, component, version = choice[0], choice[1], choice[2]
        self.connector_choice = connector
        self.component_choice = component
        self.version_choice = version

        # The generated module is saved as <component>_<connector>_<version>.py
        # inside the workspace.
        file_name = f"{component}_{connector}_{version}.py"
        self.target_file_path = self.workspace + "/" + file_name

        # Every target starts with empty feedback histories.
        self.target_code_history = []
        self.target_error_history = []
        self.target_traceback_history = []
        self.explanation_history = []

        self.train_val_ratio = train_val_ratio


    def load_logs(self):
        """Read the stored plans and build the prompter and unit tester from them.

        Must run after configure(), because the tester is created with
        ``self.train_val_ratio``.
        """
        stored = self.information_instance.read()
        self.plans = stored["plans"]

        prompter_obj = prompter(generated_connector_plan=self.plans)
        self.prompter_instance = prompter_obj
        prompter_obj.task_description_conversion(self.task_description)

        self.test_instance = unit_testing.tester(
            self.workspace,
            self.plans,
            train_val_ratio=self.train_val_ratio,
        )


    def run(self):
        """Generate, save and unit-test code for the configured target, retrying on failure.

        Each attempt rebuilds the prompt via set_initial_prompt(). From the
        second attempt on (``self.counter >= 1``) the prompt includes feedback
        from earlier attempts. The attempt then queries the model, writes the
        extracted code to ``self.target_file_path`` and runs the unit tests.
        Any exception raised inside an attempt is printed, and the next attempt
        starts after a one-second pause.

        At most ``self.patience`` attempts are made. An attempt counts even
        when building its prompt raises.

        Returns:
            bool: False if an attempt passed the tests (recorded with
            update_log_success), True if all attempts were used up (recorded
            with update_log_fail).
        """
        self.counter = 0

        while True:
            # Check the budget before building the next prompt. For
            # "simulated_data", prompt construction already queries the model,
            # which would be wasted on an attempt that is never made.
            if self.counter >= self.patience:
                self.update_log_fail()
                return True

            try:
                try:
                    # reset_prompt() reads self.counter to decide whether to
                    # attach feedback, so build the prompt before incrementing.
                    self.set_initial_prompt()
                finally:
                    # Count the attempt even if prompt construction raised;
                    # otherwise a persistent failure there would retry forever.
                    self.counter += 1

                answer = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
                extracted_answer = self.gpt_instance.extract_answer(answer, "python code")

                if self.debug:
                    print("-----Extracted code from the response:-----")
                    print(extracted_answer)

                # Python decodes source files as UTF-8 by default, so write the
                # module that way regardless of the platform's locale encoding.
                with open(self.target_file_path, "w", encoding="utf-8") as f:
                    f.write(extracted_answer)

                print(f"-----The generated code is saved to the file {self.target_file_path}-----")
                print("-----Testing the generated code-----")
                # NOTE: test_with_time_control() is not used because it currently
                # fails for some data preparation modules: train_loader.dataset[0]
                # cannot be printed during unit testing.
                if self.test_without_time_control():
                    break

                print("-----Retrying-----")
                self.target_code_history.append(extracted_answer)
                if self.explanation:
                    self.explain()
                if self.debug:
                    self._print_feedback_history()

            except Exception as fail:
                print(f"-----Code generation for the {self.version_choice+1}th version of {self.component_choice} of the {self.connector_choice}th connector failed, retrying...-----")
                print("Reason of failure:", fail)
                print(traceback.format_exc())
                time.sleep(1)

        # Record success outside the retry loop. A logging error then surfaces
        # to the caller instead of triggering regeneration of code that
        # already passed its tests.
        self.update_log_success(self.counter)
        return False

    def _print_feedback_history(self):
        """Print the code, error, traceback (and, if enabled, explanation) histories."""
        print("-----Tests history:-----")
        print("-----Code history-----")
        print(self.target_code_history)
        print("-----Error history-----")
        print(self.target_error_history)
        print("-----Traceback history-----")
        print(self.target_traceback_history)
        if self.explanation:
            print("-----Explanation history-----")
            print(self.explanation_history)

    def explain(self):
        """Ask the model to analyse the failed attempts so far and store its answer.

        All feedback histories, including earlier explanations, are attached
        to the prompt before the explanation prompt is built. The model's
        answer is appended to ``self.explanation_history``.
        """
        prompter_obj = self.prompter_instance
        histories = (
            self.target_code_history,
            self.target_error_history,
            self.target_traceback_history,
            self.explanation_history,
        )
        prompter_obj.add_feedbacks(*histories)
        prompter_obj.explain_prompt()

        if self.debug:
            print("-----Analyzing error-----")

        analysis = self.gpt_instance.send_inquiry({}, prompter_obj.prompt_message)
        self.explanation_history.append(analysis)

    def update_log_success(self,counter):
        """Record a successful generation of the configured version and persist the plans.

        Adds ``self.version_choice`` to
        ``plans[connector]["generated_<component>_modules"]`` (deduplicated).
        Stores ``counter`` under
        ``plans[connector]["<component>_generation_status"][str(version)]``.
        Either entry is created afresh when it is missing or unusable.

        Args:
            counter: value stored as the generation status; run() passes the
                number of attempts it used.
        """
        plan = self.plans[self.connector_choice]
        modules_key = f"generated_{self.component_choice}_modules"
        status_key = f"{self.component_choice}_generation_status"
        version_key = str(self.version_choice)

        try:
            modules = plan[modules_key]
            modules.append(self.version_choice)
            # NOTE: set() deduplicates but does not preserve insertion order.
            plan[modules_key] = list(set(modules))
        except (KeyError, AttributeError, TypeError):
            # Start a new list in three cases: the entry is missing (KeyError),
            # it is not a list, e.g. None (AttributeError), or it holds
            # unhashable items (TypeError). Narrowed from a bare `except:` so
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            plan[modules_key] = [self.version_choice]

        try:
            plan[status_key][version_key] = counter
        except (KeyError, TypeError):
            # Missing status dict, or a value that is not item-assignable.
            plan[status_key] = {version_key: counter}

        self.information_instance.update_plans(self.plans)

    def update_log_fail(self):
        """Mark the configured version as "failed" in the plans and persist them.

        Sets ``plans[connector]["<component>_generation_status"][str(version)]``
        to ``"failed"``. The status dict is created afresh when it is missing
        or not item-assignable.
        """
        plan = self.plans[self.connector_choice]
        status_key = f"{self.component_choice}_generation_status"
        version_key = str(self.version_choice)

        try:
            plan[status_key][version_key] = "failed"
        except (KeyError, TypeError):
            # Missing entry (KeyError) or a value such as None that does not
            # support item assignment (TypeError). Narrowed from a bare
            # `except:` so KeyboardInterrupt/SystemExit are not swallowed.
            plan[status_key] = {version_key: "failed"}

        self.information_instance.update_plans(self.plans)

    def reset_prompt(self):
        """Clear the prompter and, after the first attempt, re-attach feedback.

        In explanation mode, code, errors and the model's explanations are
        attached, with no tracebacks. This happens even when feed_backs is
        False. Otherwise, when feedback is enabled, code, errors and
        tracebacks are attached.
        """
        self.prompter_instance.reset()

        if self.counter < 1:
            # First attempt: there is no feedback yet.
            return

        if self.explanation:
            self.prompter_instance.add_feedbacks(
                self.target_code_history,
                self.target_error_history,
                None,
                self.explanation_history,
            )
        elif self.feedbacks:
            self.prompter_instance.add_feedbacks(
                self.target_code_history,
                self.target_error_history,
                self.target_traceback_history,
            )

    @exception_handler
    def set_initial_prompt(self):
        """Reset the prompter and build the generation prompt for the configured component.

        Supported components are "simulated_data", "modeling",
        "data_preparation" and "post_processing". Any other value leaves the
        prompter untouched.

        For "simulated_data", the model is first asked for a requirements
        analysis, and its answer goes into the final prompt.

        For the other components, the prompt variant depends on the
        connector's domain, ``self.plans[connector]["connector"][0]``:
        - "computer vision" / "natural language processing": also passes
          ``self.fixed_dimension_size``.
        - "traditional algorithms": does not pass it.
        - any other domain: the prompt is left freshly reset.
        """
        component = self.component_choice
        if component not in ("simulated_data", "modeling", "data_preparation", "post_processing"):
            return

        self.reset_prompt()

        if component == "simulated_data":
            print("-----Analyzing requirements for simulated data-----")
            self.prompter_instance.simulated_data_analysis_prompt(self.connector_choice)
            analysis = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
            print("-----Generating the code for producing simulated data-----")
            self.prompter_instance.simulated_data_module_prompt(analysis, self.connector_choice, self.fixed_dimension_size)
            return

        connector = self.connector_choice
        domain = self.plans[connector]["connector"][0]
        prompter_obj = self.prompter_instance

        if domain in ("computer vision", "natural language processing"):
            if component == "modeling":
                prompter_obj.modeling_module_prompt(connector, self.version_choice, self.fixed_dimension_size)
            elif component == "data_preparation":
                prompter_obj.data_preparation_module_prompt(connector, self.version_choice, self.fixed_dimension_size, self.train_val_ratio)
            else:  # "post_processing"
                prompter_obj.post_processing_module_prompt(connector, self.fixed_dimension_size)
        elif domain == "traditional algorithms":
            if component == "modeling":
                prompter_obj.modeling_module_prompt(connector, self.version_choice)
            elif component == "data_preparation":
                prompter_obj.data_preparation_module_prompt(connector, self.version_choice, train_val_ratio=self.train_val_ratio)
            else:  # "post_processing"
                prompter_obj.post_processing_module_prompt(connector)

    def test_without_time_control(self):
        """Run the unit tests in-process and sync the feedback histories.

        Returns:
            bool: True when the generated code passed the tests.
        """
        results = {}
        self.test_with_feedbacks(results)
        passed = results["break_flag"]
        self.target_error_history, self.target_traceback_history = (
            results["error_history"],
            results["tb_history"],
        )
        return passed


    def test_with_time_control(self):
        """Run the unit tests through overtime_kill with a limit of 90 and sync histories.

        Returns:
            bool: the "break_flag" reported by test_with_feedbacks (True when
            the tests passed).

        NOTE(review): test_with_feedbacks() requires a ``ret_dic`` argument,
        but an empty args tuple is passed here. Presumably overtime_kill
        supplies the result dict itself, since it returns one as its second
        value; confirm this, and confirm that 90 is in seconds. If the limit is
        hit, the dict may lack "break_flag", in which case this would raise
        KeyError.
        """
        _, results = overtime_kill(self.test_with_feedbacks, (), 90, True)
        passed = results["break_flag"]
        self.target_error_history = results["error_history"]
        self.target_traceback_history = results["tb_history"]
        return passed

    def test_with_feedbacks(self, ret_dic):
        """Run test_base() once and report the outcome through ``ret_dic``.

        On failure, the error message and the formatted traceback are appended
        to the instance histories and printed.

        ``ret_dic`` receives:
        - "break_flag": True when the tests passed;
        - "error_history": the current error history;
        - "tb_history": the current traceback history.
        This presumably lets callers read the outcome back even when this runs
        through overtime_kill.
        """
        try:
            self.test_base()
            print(f"-----Code generation for the {self.version_choice+1}th version of {self.component_choice} of the {self.connector_choice}th connector succeeds!-----")
            passed = True
        except Exception as exc:
            passed = False
            message = str(exc)
            self.target_error_history.append(message)
            trace = str(traceback.format_exc())
            self.target_traceback_history.append(trace)
            print("Code fails unit tests")
            print("Errors:\n", message, trace)

        outcome = (
            ("break_flag", passed),
            ("error_history", self.target_error_history),
            ("tb_history", self.target_traceback_history),
        )
        for name, value in outcome:
            ret_dic[name] = value

    def test_base (self):
        """Run the unit tests for the generated module of the configured target.

        For "computer vision" / "natural language processing" connectors:
        - the CUDA cache is emptied first;
        - the tester is configured and run ``self.test_sampling_count`` times;
        - each run uses a fresh ``random_search()`` configuration of the search
          space, with "simulated_data_0_0" passed as the last configure
          argument.

        For "traditional algorithms" connectors, the tests run once without a
        sampled configuration.

        Exceptions raised by the tester propagate to the caller.
        """
        gc.collect()

        domain = self.plans[self.connector_choice]["connector"][0]
        is_neural = domain in ("computer vision", "natural language processing")
        is_traditional = domain == "traditional algorithms"

        if is_neural:
            # Imported here so torch is only required for neural connectors.
            # Importing it inside an `except` clause, as before, made `torch`
            # a local name, so the preceding `try` always failed with
            # UnboundLocalError.
            import torch
            torch.cuda.empty_cache()
            time.sleep(1)

        importlib.reload(unit_testing)

        # Keep the count local. Assigning to self.test_sampling_count here
        # would permanently lower sampling for later targets configured on
        # this instance.
        sampling_count = 1 if is_traditional else self.test_sampling_count

        for _ in range(sampling_count):
            if is_neural:
                configuration, fixed_dimensions = self.search_space_instance.random_search()
                self.test_instance.configure([self.connector_choice, self.component_choice, self.version_choice], configuration, fixed_dimensions, "simulated_data_0_0")
            elif is_traditional:
                self.test_instance.configure([self.connector_choice, self.component_choice, self.version_choice])
            self.test_instance.run()

    def connector_search_space_formation(self,load_only=False):
        """Build the connector-specific hyperparameter search space and cache its fixed-dimension size.

        Args:
            load_only: if False, run the extractor to produce the raw search
                space and hand it to the search space via load_raw_space().
                If True, reuse the raw space the extractor loads with
                load_raw_search_space(), and assign it directly to
                ``raw_search_space``.

        Sets ``self.fixed_dimension_size``, which set_initial_prompt() reads
        when building simulated-data and CV/NLP prompts.
        """
        # NOTE(review): connector index 0 is hard-coded for both the extractor
        # and the search space, so only the first connector's space is built.
        # Confirm this is intended.
        self.extractor_instance.configure("connector-specific", [0, None, None])
        self.extractor_instance.load_logs()

        if not load_only:
            raw_space = self.extractor_instance.run()
        else:
            self.extractor_instance.load_raw_search_space()
            raw_space = self.extractor_instance.raw_space

        self.search_space_instance.configure("connector-specific hyperparameters", connector_choice=0)
        self.search_space_instance.load_logs()
        if not load_only:
            self.search_space_instance.load_raw_space(raw_space=raw_space)
        else:
            # NOTE(review): this path bypasses load_raw_space(). Confirm that
            # method does nothing beyond setting raw_search_space.
            self.search_space_instance.raw_search_space = raw_space
        self.search_space_instance.form_space()
        self.fixed_dimension_size=self.search_space_instance.extract_fixed_dimension_size()