from prompts import prompter
from base_modules.inqury import GPT
from base_modules.code_management import exception_handler
import keyword

class planner():
    """Ask a GPT model for the plans of a machine-learning task and validate them.

    Three kinds of plans are produced, in this order (see ``run``):

    * the *connector* plan: a list whose first three items are a task
      family, a task type and an output type, optionally followed by further
      items describing the input/output tensors;
    * the *modeling* plans: a list of single-entry dictionaries;
    * the *data preparation* plans: a list of dictionaries with a fixed key
      order that depends on the task family.

    Every answer is checked with plain ``assert`` statements, so a malformed
    answer surfaces as an ``AssertionError``. The generation methods are
    wrapped in ``exception_handler``, which is defined outside this file.
    NOTE(review): presumably that decorator retries on failure -- confirm
    before converting the assertions to another exception type.
    """

    # Accepted values for connector_plan[0] (the task family).
    basic_task_types_1 = ["traditional algorithms", 'natural language processing', 'computer vision', "audio signal processing","multi-modal"]
    # Accepted values for connector_plan[1] (the task type).
    basic_task_types_2 = ["binary classification", 'multi-class classification', 'multi-label classification', "multi-output regression",
                          'single-output regression', 'sequence-to-sequence', 'sequence generation', 'image generation', 'other']

    # Accepted values for connector_plan[2] (the output representation).
    output_types = ["probability representation of class labels" , "integer representation of class labels" , "regression output"]

    # Keys a data preparation plan must contain, in this exact order, for each
    # task family that is checked by data_preparation_plans_tests.
    _data_preparation_keys = {
        "traditional algorithms": ["data loading", "data cleaning", "encoding",
                                   "data preprocessing", "feature engineering"],
        "computer vision": ["resizing", "data augmentation", "normalization"],
        "natural language processing": ["data loading", "text cleaning", "data augmentation",
                                        "task-specific feature engineering"],
    }
    # Position words used in the assertion messages of data_preparation_plans_tests.
    _ordinals = ["first", "second", "third", "fourth", "fifth"]

    def __init__(self,task_description,mode,key,model,planner_configuration,debug=True):
        """Store the task and create the prompter and GPT helpers.

        Args:
            task_description: Passed to
                ``prompter.task_description_conversion`` by ``run``.
            mode: Stored on the instance; not read anywhere in this class.
            key: Forwarded to ``GPT`` as its first argument.
            model: Forwarded to ``GPT`` as ``model``.
            planner_configuration: Forwarded to ``prompter``; this class reads
                its ``"maximum_models"`` and ``"maximum_data_preparations"``
                entries.
            debug: When true, ``run`` prints the assembled plan at the end.
        """
        self.task_description=task_description
        self.mode=mode
        self.planner_configuration=planner_configuration
        self.prompter_instance = prompter(self.planner_configuration)
        self.gpt_instance=GPT(key,model=model)
        self.debug=debug
        self.connector_set=[]
        self.connector_plan=None
        self.data_preparation_plans=[]
        self.modeling_plans=None
        self.loss_function_plans=None

    def run(self):
        """Generate the connector, modeling and data preparation plans.

        Each plan is printed once it has been generated. The three plans are
        then appended to ``self.connector_set`` as one dictionary, which is
        printed as well when ``self.debug`` is true.
        """
        self.prompter_instance.task_description_conversion(self.task_description)

        self.connector_plans_generation()
        print("-----Plans for the task type:-----")
        print(self.connector_plan)

        # The prompter receives the connector plan before the remaining
        # prompts are built.
        self.prompter_instance.connector_plan=self.connector_plan
        self.prompter_instance.connector_plans_answers_prompt()

        self.modeling_plans_generation()
        print("-----Plans for modeling:-----")
        print(self.modeling_plans)

        self.data_preparation_plans_generation()
        print("-----Plans for data preparation:-----")
        print(self.data_preparation_plans)

        self.append_plans()
        if self.debug:
            print("-----The full plan is:-----")
            print(self.connector_set)

    def append_plans(self):
        """Append the current plans to ``self.connector_set`` as one dictionary."""
        self.connector_set.append({
            "connector": self.connector_plan,
            "data preparation": self.data_preparation_plans,
            "modeling": self.modeling_plans,
        })

    def plans_generation_base(self, type,raw_analysis=None):
        """Build the prompt for one planning step and send it to the model.

        Args:
            type: One of ``"connector_analysis"``, ``"connector"``,
                ``"data preparation"`` or ``"modeling"``.
            raw_analysis: Forwarded to ``prompter.connector_plans_prompt``;
                only used when ``type`` is ``"connector"``.

        Returns:
            Whatever ``GPT.send_inquiry`` returns for the built prompt.

        Raises:
            ValueError: If ``type`` is not one of the four known steps, in
                which case no prompt was built and nothing is sent.
        """
        self.prompter_instance.reset()

        if type=="connector_analysis":
            print("-----Analyzing the task types-----")
            self.prompter_instance.connector_plans_analysis()
        elif type=="connector":
            print("-----Extracting task plans-----")
            self.prompter_instance.connector_plans_prompt(raw_analysis)
        elif type=="data preparation":
            print("-----Planning for data preparation-----")
            self.prompter_instance.data_preparation_plans_prompt()
        elif type=="modeling":
            print("-----Planning for modeling-----")
            self.prompter_instance.modeling_plans_prompt()
        else:
            # Without this branch an unknown step would send the prompt left
            # behind by reset() instead of failing visibly.
            raise ValueError(f"unknown plan type: {type!r}")

        return self.gpt_instance.send_inquiry({}, self.prompter_instance.prompt_message)

    @exception_handler
    def connector_plans_generation(self):
        """Generate, validate and post-process the connector plan.

        The model is asked twice: first for an analysis, then for the plan
        itself with that analysis passed along. The extracted plan is stored
        in ``self.connector_plan``, validated, and finally annotated by
        ``dimension_naming``.
        """
        raw_analysis=self.plans_generation_base("connector_analysis")
        raw_answer=self.plans_generation_base("connector",raw_analysis)
        self.connector_plan=self.gpt_instance.extract_answer(raw_answer, "python object")
        self.connector_plans_tests(self.connector_plan)
        self.dimension_naming()

    def connector_plans_tests(self,answer):
        """Assert that ``answer`` is a well-formed connector plan.

        The first three items are always checked against the class-level
        lists. For "computer vision" and "natural language processing" the
        fourth item must be a string and the fifth a dictionary describing the
        input and output tensors, which are validated one by one.

        NOTE: other task families are not validated beyond the first three
        items, although ``dimension_naming`` reads ``answer[4]`` for every
        family except "traditional algorithms".

        Raises:
            AssertionError: On the first violated rule.
        """
        assert isinstance(answer, list), "answer is not a list"
        assert len(answer) >= 3, "answer does not have at least three elements"
        assert answer[0] in planner.basic_task_types_1, f"answer[0] is not one of {planner.basic_task_types_1}"
        assert answer[1] in planner.basic_task_types_2, f"answer[1] is not one of {planner.basic_task_types_2}"
        assert answer[2] in planner.output_types, f"answer[2] is not one of {planner.output_types}"

        if answer[0] not in ("computer vision", "natural language processing"):
            return

        assert len(answer) >= 5, "answer does not have at least five elements"
        assert isinstance(answer[3], str), "answer[3] is not a string"
        assert isinstance(answer[4], dict), "answer[4] is not a dictionary"

        for I_O in ["input","output"]:
            assert I_O in answer[4], f"answer[4] does not have key {I_O}"
            tensors = answer[4][I_O]
            assert isinstance(tensors, list), f"answer[4][{I_O}] is not a list"
            assert len(tensors) >= 1, f"answer[4][{I_O}] does not have at least one tensor"
            if answer[0]=="natural language processing" and I_O=="input":
                assert len(tensors) >= 2, f"answer[4][{I_O}] does not have at least two tensors"

            # Maps each tensor name to the index where it was first seen, so
            # a duplicate can be reported together with its first occurrence.
            first_index_of_name = {}
            for index, tensor in enumerate(tensors):
                assert isinstance(tensor, dict), f"answer[4][{I_O}][{index}] is not a dictionary"
                # The name must be verified before it is compared or searched.
                assert "name" in tensor, f"answer[4][{I_O}] does not have key 'name'"
                name = tensor["name"]
                assert isinstance(name, str), f"answer[4][{I_O}]['name'] is not a string"
                assert name not in first_index_of_name, f"answer[4][{I_O}][{first_index_of_name[name]}]['name'] is the same as answer[4][{I_O}][{index}]['name']"
                first_index_of_name[name] = index
                planner._tensor_tests(tensor, I_O, answer[0])

    @staticmethod
    def _tensor_tests(tensor, I_O, task_type):
        """Assert that one tensor description has a valid shape and dimensions.

        Args:
            tensor: Dictionary with a string ``"name"`` (already verified by
                the caller), a ``"shape"`` list and one entry per dimension.
            I_O: ``"input"`` or ``"output"``; only used in the messages.
            task_type: The task family, i.e. ``answer[0]``.

        Raises:
            AssertionError: On the first violated rule.
        """
        if task_type=="computer vision":
            assert "attention" not in tensor["name"], f"answer[4][{I_O}]['name'] contains the word 'attention'"

        assert "shape" in tensor, f"answer[4][{I_O}] does not have key 'shape'"
        shape = tensor["shape"]
        assert isinstance(shape, list), f"answer[4][{I_O}]['shape'] is not a list"
        # The length is tested first so that an empty shape fails the
        # assertion instead of raising IndexError.
        assert len(shape) >= 1 and shape[0]=="batch_size", f"answer[4][{I_O}]['shape'][0] is not 'batch_size'"
        for dimension in shape:
            assert isinstance(dimension, str), f"answer[4][{I_O}]['shape'] has a dimension that is not a string"
        assert len(shape) == len(set(shape)), f"answer[4][{I_O}]['shape'] has duplicate elements"

        for dimension in shape:
            assert dimension in tensor, f"answer[4][{I_O}]['shape'] does not have key {dimension}"
            details = tensor[dimension]
            assert isinstance(details, dict), f"answer[4][{I_O}]['shape'][{dimension}] is not a dictionary"
            assert "meaning" in details, f"answer[4][{I_O}]['shape'][{dimension}] does not have key 'meaning'"
            assert isinstance(details["meaning"], str), f"answer[4][{I_O}]['shape'][{dimension}]['meaning'] is not a string"
            assert "fixed_or_variable" in details, f"answer[4][{I_O}]['shape'][{dimension}] does not have key 'fixed_or_variable'"
            assert details["fixed_or_variable"] in ["fixed","variable"], f"answer[4][{I_O}]['shape'][{dimension}]['fixed_or_variable'] is not 'fixed' or 'variable'"
            if details["fixed_or_variable"] == "variable":
                # Spaces are tolerated in dimension names and turned into
                # underscores; dimension_naming applies the same cleaning.
                dimension_raw_name = dimension.replace(" ", "_")
                assert planner.raw_name_validity_check(dimension_raw_name), f"the name of the dimension is not a valid variable name"

        # Checked last: the loop above has established that the "batch_size"
        # entry exists and carries the "fixed_or_variable" key.
        assert tensor["batch_size"]["fixed_or_variable"] == "variable", f"answer[4][{I_O}]['batch_size']['fixed_or_variable'] is not 'variable'"

    @staticmethod
    def raw_name_validity_check(name):
        """Return True when ``name`` can be used as a Python variable name.

        A valid name is a non-empty identifier that is not a reserved
        keyword. The empty string is rejected rather than raising.
        """
        return name.isidentifier() and not keyword.iskeyword(name)

    def dimension_naming(self):
        """Attach a ``"variable_name"`` to the variable dimensions of every tensor.

        Nothing is done for "traditional algorithms". Otherwise every tensor's
        ``"batch_size"`` entry is named ``"batch_size"``, and every other
        dimension marked ``"variable"`` is named after itself with spaces
        replaced by underscores. When the same cleaned name occurs more than
        once among the non-reserved keys of all tensors, the name is prefixed
        with the tensor's name.

        The plan in ``self.connector_plan`` is modified in place.

        Raises:
            AssertionError: If a resulting name is not a valid variable name.
        """
        if self.connector_plan[0]=="traditional algorithms":
            return

        reserved_keys = ("name", "shape", "batch_size")
        tensors = [tensor
                   for I_O in ("input", "output")
                   for tensor in self.connector_plan[4][I_O]]

        # Name the batch size dimensions first, and count how often each
        # cleaned key occurs across all tensors.
        occurrences = {}
        for tensor in tensors:
            tensor["batch_size"].update({"variable_name": "batch_size"})
            for raw_name in tensor:
                if raw_name not in reserved_keys:
                    cleaned_name = raw_name.replace(" ", "_")
                    occurrences[cleaned_name] = occurrences.get(cleaned_name, 0) + 1

        # Derive each name directly from its own dimension. Searching the
        # candidate names for a substring match could find nothing when the
        # dimension contains spaces, or pick the name of another dimension.
        for tensor in tensors:
            for dimension in tensor["shape"]:
                if dimension == "batch_size":
                    continue
                if tensor[dimension]["fixed_or_variable"] != "variable":
                    continue

                variable_name = dimension.replace(" ", "_")
                if occurrences.get(variable_name, 0) > 1:
                    variable_name = f"{tensor['name']}_{variable_name}"
                # The tensor name used as prefix may itself contain spaces.
                variable_name = variable_name.replace(" ", "_")

                assert planner.raw_name_validity_check(variable_name), f"the dimension name {variable_name} is not a valid variable name"
                tensor[dimension]["variable_name"] = variable_name

    @exception_handler
    def modeling_plans_generation(self):
        """Generate the modeling plans, store them and validate them."""
        raw_answer=self.plans_generation_base("modeling")
        self.modeling_plans=self.gpt_instance.extract_answer(raw_answer, "python object")
        self.modeling_plans_tests(self.modeling_plans)

    def modeling_plans_tests(self,answer):
        """Assert that ``answer`` is a well-formed list of modeling choices.

        Every choice must be a non-empty dictionary with string keys. The
        value of its first entry is either a string or a dictionary of strings
        that contains a ``"reason"`` key.

        Raises:
            AssertionError: On the first violated rule.
        """
        assert isinstance(answer, list), "answer is not a list"
        assert len(answer)>=self.planner_configuration["maximum_models"], f"answer is not a list of length {self.planner_configuration['maximum_models']} or more"

        for choice in answer:
            assert isinstance(choice, dict), f"a choice in the list is not a dictionary"
            # An empty choice would otherwise raise IndexError below.
            assert len(choice) >= 1, f"a choice in the list is an empty dictionary"
            assert all(isinstance(key, str) for key in choice), f"a key in a choice is not a string"

            choice_name, choice_value = next(iter(choice.items()))
            assert isinstance(choice_value, (str, dict)), f"the value of a choice is not a string or a dictionary"
            if not isinstance(choice_value, dict):
                continue

            for key, value in choice_value.items():
                assert isinstance(key, str), f"Key '{key}' of a choice is not a string"
                assert isinstance(value, str), f"Value '{value}' for key '{key}' of a choice is not a string"

            assert "reason" in choice_value, f"the value of a choice does not have key 'reason'"
            # Every key after the first must appear in the choice's own name.
            # NOTE(review): this presumably relies on "reason" being the first
            # key -- verify against the modeling prompt.
            for individual_model in list(choice_value)[1:]:
                assert individual_model in choice_name, f"an individual model {individual_model} in a" \
                                                        f" combination is not present in the combination {choice_name}"

    @exception_handler
    def data_preparation_plans_generation(self):
        """Generate ``maximum_data_preparations`` validated data preparation plans.

        Plans are requested one at a time; from the second request on, the
        plans accepted so far are loaded into the prompter first. Only plans
        that pass ``data_preparation_plans_tests`` are kept.
        """
        self.prompter_instance.generated_data_preparation_plans=""
        # Start from a fresh list, like the prompter state above: a repeated
        # call would otherwise keep the plans of the previous call and end up
        # with more than the configured maximum. Rebinding also leaves a list
        # already stored in connector_set untouched.
        self.data_preparation_plans=[]

        while len(self.data_preparation_plans) < self.planner_configuration["maximum_data_preparations"]:
            if self.data_preparation_plans:
                self.prompter_instance.load_generated_data_preparation_plans(self.data_preparation_plans)
            raw_answer=self.plans_generation_base("data preparation")
            single_plan=self.gpt_instance.extract_answer(raw_answer, "python object")
            self.data_preparation_plans_tests(single_plan)
            self.data_preparation_plans.append(single_plan)

    def data_preparation_plans_tests(self,answer):
        """Assert that ``answer`` is a well-formed data preparation plan.

        The plan must be a dictionary of strings. For the task families listed
        in ``_data_preparation_keys`` it must also contain exactly the
        expected keys in the expected order; other families are not checked
        any further.

        Raises:
            AssertionError: On the first violated rule.
        """
        assert isinstance(answer, dict), "answer is not a dictionary"
        assert all(isinstance(key, str) for key in answer), f"a key in answer is not a string"
        assert all(isinstance(value, str) for value in answer.values()), f"a value in answer is not a string"

        expected_keys = planner._data_preparation_keys.get(self.connector_plan[0])
        if expected_keys is None:
            return

        assert len(answer) == len(expected_keys), f"answer does not have {len(expected_keys)} keys"
        # The position word is generated from the actual index, so the message
        # always names the key that is really out of place.
        for ordinal, actual_key, expected_key in zip(planner._ordinals, answer, expected_keys):
            assert actual_key == expected_key, f"the {ordinal} key in answer is not '{expected_key}'"