from prompts import prompter
from base_modules.inqury import GPT
from base_modules.code_management import exception_handler
import random
from copy import deepcopy
from log import log_class
import math

class extractor():
    """Use an LLM to extract hyperparameter information for a connector plan.

    Prompts are built by a ``prompter`` instance, sent through a ``GPT`` client,
    parsed with ``GPT.extract_answer`` and checked by the matching ``*_test``
    method; accepted results are written back to the workspace plan log via
    ``log_class.update_plans``.

    Entries of ``self.plans[connector_choice]["connector"]`` used below
    (NOTE(review): inferred only from the indices this class reads/writes):
      * ``[2]`` -- compared against ``"integer representation of class labels"``;
      * ``[4]`` -- ``{"input": [...], "output": [...]}`` tensor plans, each with a
        ``"shape"`` list and, per dimension, ``"fixed_or_variable"`` and (for
        variable dimensions) ``"variable_name"``;
      * ``[5]`` -- read as the isomorphic-dimension mappings (``isomorphic_dimensions``
        appends them to the list).
    """

    def __init__(self, task_description, key, model, workspace, debug=False):
        """Create the extractor.

        Args:
            task_description: task text passed to the prompter in connector-specific mode.
            key: API key forwarded to ``GPT``.
            model: model name forwarded to ``GPT``.
            workspace: workspace forwarded to ``log_class`` (source/sink of the plans).
            debug: when True, print prompts and intermediate answers.
        """
        self.task_description = task_description
        self.gpt_instance = GPT(key, model=model)
        self.debug = debug

        self.information_instance = log_class(workspace)

    def configure(self, mode, choice):
        """Select the extraction mode and the plan entries to work on.

        Args:
            mode: ``"connector-specific"`` or ``"version-specific"`` (see ``run``).
            choice: sequence whose first three items are the connector, component
                and version choices, in that order.
        """
        self.mode = mode

        self.connector_choice = choice[0]
        self.component_choice = choice[1]
        self.version_choice = choice[2]

    def load_logs(self):
        """Load the ``plans`` log entry and build a prompter around it."""
        self.plans = self.information_instance.read()["plans"]
        self.prompter_instance = prompter(generated_connector_plan=self.plans)

    def load_raw_search_space(self):
        """Load the ``raw_search_space`` log entry into ``self.raw_space``."""
        self.raw_space = self.information_instance.read()["raw_search_space"]

        # marker: fix here

    def run(self, input_code=None):
        """Execute the configured extraction mode.

        Args:
            input_code: source code to analyse; only used in version-specific mode.

        Returns:
            * connector-specific mode: the hyperparameter ranges returned by
              ``connector_extraction``;
            * version-specific mode: ``(output_code, space)`` from
              ``version_extraction_code`` and ``version_extraction_space``;
            * any other mode: ``None``.
        """
        if self.mode == "connector-specific":
            self.prompter_instance.task_description_conversion(self.task_description)
            # connector_extraction reads connector[5]; only ask for the
            # isomorphic dimensions when the plan is still shorter than that.
            if len(self.plans[self.connector_choice]["connector"]) < 6:
                self.isomorphic_dimensions()
            return self.connector_extraction()
        elif self.mode == "version-specific":
            output_code = self.version_extraction_code(input_code)
            space = self.version_extraction_space(output_code)
            return output_code, space

    @staticmethod
    def unify_isomorphic_dimensions(raw_hyperparameters, isomorphic_dimensions):
        """Rename aliased dimensions to their shared (isomorphic) canonical name.

        Args:
            raw_hyperparameters: dict with ``"input"`` and ``"output"`` lists; each
                item is a dict mapping dimension name -> range.
            isomorphic_dimensions: iterable of dicts mapping a canonical name to the
                collection of dimension names that should be renamed to it.

        Returns:
            A deep copy of ``raw_hyperparameters`` in which every aliased key is
            replaced by its canonical name; the argument itself is not modified.
            If an alias is listed under several canonical names, the last listed
            one wins.
        """
        # Later entries overwrite earlier ones, giving "last match wins".
        alias_to_canonical = {
            alias: canonical
            for mapping in isomorphic_dimensions
            for canonical in mapping
            for alias in mapping[canonical]
        }

        new_hyperparameters = deepcopy(raw_hyperparameters)
        for I_O in ["input", "output"]:
            # Iterate the untouched original so renaming cannot disturb the loop.
            for tensor_idx, tensor in enumerate(raw_hyperparameters[I_O]):
                renamed_tensor = new_hyperparameters[I_O][tensor_idx]
                for dimension in tensor:
                    if dimension in alias_to_canonical:
                        renamed_tensor[alias_to_canonical[dimension]] = renamed_tensor.pop(dimension)
        return new_hyperparameters

    def retrieve_generated_modeling_modules_by_name(self):
        """Return the names of the modeling modules listed as generated.

        ``generated_modeling_modules`` holds indices into the ``modeling`` plan
        list; each referenced plan entry is a dict whose first key is its name.
        """
        connector_plan = self.plans[self.connector_choice]
        planned_modeling_modules_list = connector_plan["modeling"]
        return [
            list(planned_modeling_modules_list[module_index].keys())[0]
            for module_index in connector_plan["generated_modeling_modules"]
        ]

    @exception_handler
    def inter_version_hyperparameters(self):
        """Ask the model for hyperparameters shared across versions and store them.

        The parsed answer is checked by ``inter_version_hyperparameters_test``
        and saved under ``plans[connector_choice]["inter_version_hyperparameters"]``.
        """
        generated_modeling_modules_by_name = self.retrieve_generated_modeling_modules_by_name()
        self.prompter_instance.reset()
        self.prompter_instance.inter_version_hyperparameters_prompt(self.connector_choice, generated_modeling_modules_by_name)

        if self.debug:
            print("prompt message for inter-version hyperparameters:")
            for message in self.prompter_instance.prompt_message:
                print(message["content"])

        raw_answer = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
        answer = self.gpt_instance.extract_answer(raw_answer, "python object")
        if self.debug:
            print("inter-version hyperparameters:", answer)
        self.inter_version_hyperparameters_test(answer)
        self.plans[self.connector_choice]["inter_version_hyperparameters"] = answer
        self.information_instance.update_plans(self.plans)

        # marker: begin write here

        # the search space should be a mapping from modeling plan to the value (code segment)
        # 1. load the modeling plan
        # 2. generate the value
        # 3. update the composite search space

        # for code generation:
        # 1. modify the code
        # 2. test the code

        # for optimization:
        # 1. use the value upon designation

    def inter_version_hyperparameters_test(self, answer):
        """Validation hook for ``inter_version_hyperparameters``; currently accepts any answer."""
        pass

    @exception_handler
    def isomorphic_dimensions(self):
        """Ask the model which connector dimensions are isomorphic; append the answer to the connector plan."""
        self.prompter_instance.reset()
        self.prompter_instance.connector_isomorphic_dimensions_prompt(self.connector_choice)
        print("-----Analyzing the isomorphic dimensions-----")
        raw_answer = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
        answer = self.gpt_instance.extract_answer(raw_answer, "python object")
        self.isomorphic_dimensions_test(answer)
        print("-----Identified isomorphic dimensions:-----")
        print(answer)
        self.plans[self.connector_choice]["connector"].append(answer)
        self.information_instance.update_plans(self.plans)

    def isomorphic_dimensions_test(self, answer):
        """Validation hook for ``isomorphic_dimensions``; currently accepts any answer."""
        pass

    @exception_handler
    def connector_extraction(self):
        """Ask the model for the connector hyperparameter ranges and normalize them.

        When connector[2] is ``"integer representation of class labels"``, the
        answer's ``num_classes`` is also appended to the connector plan and saved.

        Returns:
            The validated answer after ``connector_replace_names`` and
            ``unify_isomorphic_dimensions`` (using connector[5]).
        """
        self.prompter_instance.reset()
        self.prompter_instance.connector_hyperparameter_extraction(self.connector_choice)
        print("-----Designating the ranges of hyperparameters-----")
        raw_answer = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
        answer = self.gpt_instance.extract_answer(raw_answer, "python object")
        self.connector_extraction_test(answer)
        print("-----Identified ranges of hyperparameters:-----")
        print(answer)

        if self.plans[self.connector_choice]["connector"][2] == "integer representation of class labels":
            self.plans[self.connector_choice]["connector"].append({"num_classes": answer["num_classes"]})
            self.information_instance.update_plans(self.plans)

        answer = self.connector_replace_names(answer)
        answer = extractor.unify_isomorphic_dimensions(answer, self.plans[self.connector_choice]["connector"][5])
        return answer

    def connector_replace_names(self, answer):
        """Rename planned *variable* dimensions in ``answer`` to their ``variable_name``.

        ``answer`` is modified in place and also returned. A variable dimension
        from the plan that is missing in ``answer`` raises ``KeyError``.
        """
        tensor_plans = self.plans[self.connector_choice]["connector"][4]
        for I_O in ["input", "output"]:
            for tensor_idx, tensor_plan in enumerate(tensor_plans[I_O]):
                answer_tensor = answer[I_O][tensor_idx]
                for dimension in tensor_plan["shape"]:
                    dimension_plan = tensor_plan[dimension]
                    if dimension_plan["fixed_or_variable"] == "variable":
                        answer_tensor[dimension_plan["variable_name"]] = answer_tensor.pop(dimension)
        return answer

    def connector_extraction_test(self, extracted_answer):
        """Validate a parsed connector-extraction answer against the connector plan.

        Raises ``AssertionError`` (with a descriptive message) when:
          * the answer is not a dict or lacks list-valued ``"input"``/``"output"``;
          * connector[2] is ``"integer representation of class labels"`` and
            ``"num_classes"`` is missing or not an int greater than 1;
          * a tensor is not a dict or has no counterpart in the plan;
          * a planned ``"variable"`` dimension is not ``[low, high]`` ints with
            ``low <= high``, or a planned ``"fixed"`` dimension is not ``[value]``
            with an int value.
        A dimension name absent from the plan surfaces as ``KeyError``.
        """
        assert isinstance(extracted_answer, dict), f"extracted_answer is type {type(extracted_answer)} instead of type dict"
        assert "input" in extracted_answer, f"extracted_answer has keys {extracted_answer.keys()} but does not have the key 'input'"

        connector_plan = self.plans[self.connector_choice]["connector"]
        if connector_plan[2] == "integer representation of class labels":
            assert "num_classes" in extracted_answer, f"extracted_answer, which is {extracted_answer}, does not have the key 'num_classes'"
            num_classes = extracted_answer["num_classes"]
            # `type(...) is int` (not isinstance) so that booleans are rejected.
            assert type(num_classes) is int and num_classes > 1, f"extracted_answer['num_classes'] is {num_classes} instead of an integer greater than 1"
        assert "output" in extracted_answer, f"extracted_answer has keys {extracted_answer.keys()} but does not have the key 'output'"
        assert isinstance(extracted_answer["input"], list), f"extracted_answer['input'] is type {type(extracted_answer['input'])} instead of type list"
        assert isinstance(extracted_answer["output"], list), f"extracted_answer['output'] is type {type(extracted_answer['output'])} instead of type list"

        for I_O in ["input", "output"]:
            planned_tensors = connector_plan[4][I_O]
            for tensor_idx, tensor in enumerate(extracted_answer[I_O]):
                assert isinstance(tensor, dict), f"extracted_answer[{I_O}][{tensor_idx}] is type {type(tensor)} instead of type dict"
                assert tensor_idx < len(planned_tensors), f"extracted_answer[{I_O}][{tensor_idx}] has no counterpart: the plan only has {len(planned_tensors)} {I_O} tensor(s)"

                for hyperparameter, value in tensor.items():
                    where = f"extracted_answer[{I_O}][{tensor_idx}][{hyperparameter}]"
                    assert isinstance(value, list), f"{where} is type {type(value)} instead of type list"
                    kind = planned_tensors[tensor_idx][hyperparameter]["fixed_or_variable"]
                    if kind == "variable":
                        assert len(value) == 2, f"{where} is {value}, which is not a list of length 2"
                        assert type(value[0]) is int, f"{where}[0] is {value[0]}, which is not an integer"
                        assert type(value[1]) is int, f"{where}[1] is {value[1]}, which is not an integer"
                        assert value[0] <= value[1], f"{where}[0] is {value[0]} and {where}[1] is {value[1]}, the first element is greater than the second element"
                    elif kind == "fixed":
                        assert len(value) == 1, f"{where} is {value}, which is not a list of length 1"
                        assert type(value[0]) is int, f"{where}[0] is {value[0]}, which is not an integer"

    @exception_handler
    def version_extraction_code(self, input_code):
        """Ask the model to rewrite ``input_code`` for hyperparameter extraction; return the extracted code.

        NOTE(review): unlike the connector methods, this does not call
        ``prompter_instance.reset()`` first -- confirm whether that is intended.
        """
        self.prompter_instance.version_hyperparameter_extraction_code(self.connector_choice, input_code)
        raw_output_code = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
        extracted_output_code = self.gpt_instance.extract_answer(raw_output_code, "python code")
        self.version_extraction_code_test(extracted_output_code)
        return extracted_output_code

    def version_extraction_code_test(self, extracted_output_code):
        """Validation hook for ``version_extraction_code``; currently accepts any code."""
        pass

    @exception_handler
    def version_extraction_space(self, code):
        """Ask the model for the search space of ``code``; return the parsed Python object.

        NOTE(review): no ``prompter_instance.reset()`` here either -- confirm intent.
        """
        self.prompter_instance.version_hyperparameter_extraction_space(self.connector_choice, code)
        raw_answer = self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)
        answer = self.gpt_instance.extract_answer(raw_answer, "python object")
        self.version_extraction_space_test(answer)
        return answer

    def version_extraction_space_test(self, extracted_answer):
        """Validation hook for ``version_extraction_space``; currently accepts any answer."""
        pass
    # marker: fill in the functions

class search_space():
    """Build, persist and randomly sample hyperparameter search spaces.

    Two stores are loaded from (``load_logs``) and written back to the
    workspace log through ``log_class``:

    * ``raw_search_space`` -- flat dict keyed by
      ``("connector-specific hyperparameters", connector_choice, I_O,
      tensor_index, hyperparameter_index)``; each value is
      ``{name: {"range": [...], "type": "int"}}``.
    * ``composite_search_space`` -- nested dict keyed first by section
      (``"connector-specific hyperparameters"``, ``"combination"``, ``"CCF"``)
      and then by connector choice.
    """

    def __init__(self, workspace=None, debug=False):
        """Bind to ``workspace`` (forwarded to ``log_class``); ``debug`` enables prints."""
        self.debug = debug
        self.workspace = workspace
        self.logger_instance = log_class(self.workspace)

    def configure(self, mode, connector_choice=None, filtered_modeling_modules=None, combination_choice=None, maximum_modeling_modules=None, maximum_data_preparation_modules=None):
        """Store the mode and the parameters used by the ``form_*`` and ``random_search`` methods.

        Args:
            mode: ``"connector-specific hyperparameters"``, ``"CCF"`` or ``"combination"``.
            connector_choice: key of the connector plan the spaces belong to.
            filtered_modeling_modules: optional explicit modeling-module choices for
                the CCF space; when ``None`` ``range(maximum_modeling_modules)`` is used.
            combination_choice: stored as-is (not read by the methods shown here).
            maximum_modeling_modules: number of modeling modules (CCF fallback).
            maximum_data_preparation_modules: number of data-preparation modules.
        """
        self.connector_choice = connector_choice
        self.combination_choice = combination_choice
        self.mode = mode

        self.filtered_modeling_modules = filtered_modeling_modules
        self.maximum_modeling_modules = maximum_modeling_modules
        self.maximum_data_preparation_modules = maximum_data_preparation_modules

    def load_logs(self):
        """Read the log; start from empty raw/composite spaces when the log has none."""
        self.logs = self.logger_instance.read()

        self.raw_search_space = self.logs["raw_search_space"] if "raw_search_space" in self.logs else {}
        self.composite_search_space = self.logs["composite_search_space"] if "composite_search_space" in self.logs else {}

    def load_raw_space(self, raw_space=None):
        """Merge ``raw_space`` into the raw search space (connector mode only; otherwise a no-op)."""
        if self.mode == "connector-specific hyperparameters":
            self.load_raw_space_connector(raw_space)

    def form_space(self):
        """Build the composite search space for the configured mode (unknown modes: no-op)."""
        if self.mode == "connector-specific hyperparameters":
            self.form_connector_hyperparameter_space()
        elif self.mode == "CCF":
            self.form_finetuning_space()
            self.form_CCF_space()
        elif self.mode == "combination":
            self.form_combination_space()

    def load_raw_space_connector(self, raw_space):
        """Flatten per-tensor connector ranges into ``raw_search_space`` and persist it.

        Args:
            raw_space: dict with ``"input"`` and ``"output"`` lists of
                ``{dimension_name: range}`` dicts.
        """
        for I_O in ["input", "output"]:
            for tensor_index, tensor in enumerate(raw_space[I_O]):
                for hyperparameter_index, (hyperparameter, value_range) in enumerate(tensor.items()):
                    index = ("connector-specific hyperparameters", self.connector_choice, I_O, tensor_index, hyperparameter_index)
                    self.raw_search_space[index] = {hyperparameter: {"range": value_range, "type": "int"}}
        if self.debug:
            print("Raw search space:")
            print(self.raw_search_space)

        self.logger_instance.update_raw_search_space(self.raw_search_space)

    def form_connector_hyperparameter_space(self):
        """Collect this connector's raw ranges into ``{name: range}`` and store/persist it.

        Returns:
            The formed space, also stored at
            ``composite_search_space["connector-specific hyperparameters"][connector_choice]``.
        """
        filtered_keys = [
            key for key in self.raw_search_space
            if isinstance(key, tuple) and len(key) == 5
            and key[0] == "connector-specific hyperparameters" and key[1] == self.connector_choice
        ]

        space = {}
        for I_O in ["input", "output"]:
            io_keys = [key for key in filtered_keys if key[2] == I_O]
            # Every key re-assigns all dimensions of its tensor; when a name occurs
            # in several tensors, the last such assignment wins.
            for tensor_key in io_keys:
                for hyperparameter_key in [key for key in io_keys if key[3] == tensor_key[3]]:
                    hyperparameter_name, spec = next(iter(self.raw_search_space[hyperparameter_key].items()))
                    space[hyperparameter_name] = spec["range"]

        self.composite_search_space.setdefault("connector-specific hyperparameters", {})[self.connector_choice] = space

        if self.debug:
            print("formed space")
            print(space)
            print("collection of space")
            print(self.composite_search_space)

        self.logger_instance.update_composite_search_space(self.composite_search_space)

        return space

    def form_combination_space(self):
        """Store the module-combination space: data-preparation indices and generated modeling modules."""
        self.composite_search_space.setdefault("combination", {})[self.connector_choice] = {
            "data_preparation": list(range(self.maximum_data_preparation_modules)),
            "modeling": self.logs["plans"][self.connector_choice]["generated_modeling_modules"],
        }

        self.logger_instance.update_composite_search_space(self.composite_search_space)

    def form_finetuning_space(self):
        """Define the connector-independent fine-tuning space (bounds or categorical choices)."""
        self.universal_finetuning_space = {
            "batch_size": [2, 4],
            "learning_rate": [1e-5, 1e-1],
            "weight_decay": [1e-4, 1e-2],
            "momentum": [0.01, 0.99],
            "optimizer": ["adam", "sgd", "adamw"],
            "scheduler": ["plateau", "cosine"],
        }

    def form_CCF_space(self):
        """Assemble and persist the CCF space for the current connector.

        Combines the module-combination choices, the connector-specific
        hyperparameter ranges (without ``batch_size``) and the universal
        fine-tuning space. Requires ``form_finetuning_space`` to have run and the
        connector-specific space to exist in ``composite_search_space``.
        """
        ccf_space = {}
        self.composite_search_space.setdefault("CCF", {})[self.connector_choice] = ccf_space

        if self.filtered_modeling_modules is not None:
            modeling_choices = self.filtered_modeling_modules
        else:
            modeling_choices = list(range(self.maximum_modeling_modules))
        ccf_space["combination"] = {
            "data_preparation": list(range(self.maximum_data_preparation_modules)),
            "modeling": modeling_choices,
        }

        # Read the in-memory composite space: it is the same object as the logged
        # one when that exists, and also sees spaces formed in this session.
        connector_specific_hyperparameters_space = self.composite_search_space["connector-specific hyperparameters"][self.connector_choice]
        # batch_size is skipped here; presumably because the universal fine-tuning
        # space defines its own batch_size range -- NOTE(review): confirm.
        ccf_space["connector-specific hyperparameters"] = {
            dimension: value_range
            for dimension, value_range in connector_specific_hyperparameters_space.items()
            if dimension != "batch_size"
        }

        ccf_space["universal_finetuning"] = self.universal_finetuning_space

        self.logger_instance.update_composite_search_space(self.composite_search_space)

    def extract_fixed_dimension_size(self):
        """Return ``{name: size}`` for this connector's raw dimensions whose range has a single value."""
        fixed_dimension_size = {}
        for key, entry in self.raw_search_space.items():
            if not (isinstance(key, tuple) and key[:2] == ("connector-specific hyperparameters", self.connector_choice)):
                continue
            name, spec = next(iter(entry.items()))
            value_range = spec["range"]
            if len(value_range) == 1:
                fixed_dimension_size[name] = value_range[0]
        return fixed_dimension_size

    @staticmethod
    def boundary_selection_int(space):
        """Sample integers from ``[low, high]`` bounds, recursing into dicts.

        ``[low, high]`` -> ``random.randint(low, high)``; ``[value]`` -> ``value``;
        a dict -> a new dict with each value sampled; other lists -> a shallow
        copy; any other value is returned unchanged. The argument is not mutated.
        """
        if isinstance(space, list):
            if len(space) == 2:
                lower_bound, upper_bound = space
                return random.randint(lower_bound, upper_bound)
            if len(space) == 1:
                return space[0]
            return list(space)
        if isinstance(space, dict):
            return {key: search_space.boundary_selection_int(value) for key, value in space.items()}
        return space

    @staticmethod
    def boundary_selection_float(space, log=False):
        """Sample floats from ``[low, high]`` bounds, recursing into dicts.

        With ``log=True`` samples uniformly in log space (bounds must be > 0).
        ``[value]`` -> ``value``; a dict -> a new dict with each value sampled;
        other lists -> a shallow copy; any other value is returned unchanged.
        """
        if isinstance(space, list):
            if len(space) == 2:
                lower_bound, upper_bound = space
                if log:
                    return math.exp(random.uniform(math.log(lower_bound), math.log(upper_bound)))
                return random.uniform(lower_bound, upper_bound)
            if len(space) == 1:
                return space[0]
            return list(space)
        if isinstance(space, dict):
            return {key: search_space.boundary_selection_float(value, log) for key, value in space.items()}
        return space

    @staticmethod
    def categorical_selection(space):
        """Return ``{key: random.choice(options)}`` for every entry of ``space``."""
        return {key: random.choice(value_list) for key, value_list in space.items()}

    def random_search(self):
        """Draw one random configuration from the composite space of the configured mode.

        Returns:
            * ``"connector-specific hyperparameters"``: ``(configuration, fixed_dimension)``
              where ``fixed_dimension`` lists the names with single-value ranges;
            * ``"CCF"``: ``(combination, connector_specific_hyperparameters, finetuning)``;
            * ``"combination"``: the combination configuration;
            * any other mode: ``None``.
        """
        if self.mode == "connector-specific hyperparameters":
            space = self.composite_search_space["connector-specific hyperparameters"][self.connector_choice]

            fixed_dimension = [key for key, value in space.items() if isinstance(value, list) and len(value) == 1]
            configuration = search_space.boundary_selection_int(space)
            if self.debug:
                print("configuration from random search:", configuration, fixed_dimension)
            return configuration, fixed_dimension
        if self.mode == "CCF":
            space = self.composite_search_space["CCF"][self.connector_choice]

            combination_configuration = search_space.categorical_selection(space["combination"])

            finetuning_space = space["universal_finetuning"]

            def subspace(names):
                # Restrict the fine-tuning space to the given dimensions that exist.
                return {name: finetuning_space[name] for name in names if name in finetuning_space}

            # Sampling order (int, uniform float, log float, categorical) is fixed
            # so that a seeded RNG yields reproducible configurations.
            finetuning_configuration = {}
            finetuning_configuration.update(search_space.boundary_selection_int(subspace(["batch_size"])))
            finetuning_configuration.update(search_space.boundary_selection_float(subspace(["momentum"])))
            finetuning_configuration.update(search_space.boundary_selection_float(subspace(["learning_rate", "weight_decay"]), log=True))
            finetuning_configuration.update(search_space.categorical_selection(subspace(["optimizer", "scheduler"])))

            connector_specific_hyperparameters_configuration = search_space.boundary_selection_int(space["connector-specific hyperparameters"])

            return combination_configuration, connector_specific_hyperparameters_configuration, finetuning_configuration
        if self.mode == "combination":
            space = self.composite_search_space["combination"][self.connector_choice]

            return search_space.categorical_selection(space)