import importlib
import numpy as np
from math import isclose
import random

import torch

from utilities import locate_dimension_in_connector
from sklearn.model_selection import train_test_split


class tester():
    """Runner for unit checks on the component modules of a workspace.

    An instance is set up with ``configure`` (connector, component and
    version choice) and ``run`` then dispatches to the check that matches
    the chosen component.
    """

    # Class-level constant; not referenced by any method visible in this class.
    range_of_batch_size=[2,32]

    # Class-level constant; not referenced by any method visible in this class.
    input_array_size=100


    def __init__(self,workspace,plans,train_val_ratio,tolerance=0.03,debug=False):
        """Store the settings shared by every component check.

        Args:
            workspace: Name of the package the component modules are
                imported from.
            plans: Mapping indexed by connector choice; each entry exposes a
                ``"connector"`` sequence that ``configure`` reads.
            train_val_ratio: Expected held-out fraction of the data.
            tolerance: Relative tolerance used when comparing split ratios.
            debug: When true, intermediate results are printed.
        """
        # The assignments are independent of each other; only storage happens here.
        self.debug, self.tolerance = debug, tolerance
        self.train_val_ratio = train_val_ratio
        self.plans = plans
        self.workspace = workspace

    def configure(self, choice, configuration=None, fixed_dimensions=None,simulated_data_file_name=None):
        """Select the connector/component/version to check and cache plan data.

        Args:
            choice: Indexable triple ``(connector, component, version)``.
            configuration: Mapping of dimension name to value, or ``None``.
            fixed_dimensions: Collection of dimension names treated as fixed,
                or ``None``.
            simulated_data_file_name: Name of the module that generates
                simulated data, or ``None``.
        """
        self.connector_choice = choice[0]

        # All plan details below come from the connector entry of the chosen plan.
        connector_entry = self.plans[self.connector_choice]["connector"]
        self.basic_task_types_1 = connector_entry[0]
        self.basic_task_types_2 = connector_entry[1]
        self.output_types = connector_entry[2]

        self.component_choice = choice[1]
        self.version_choice = choice[2]
        self.configuration = configuration
        self.fixed_dimensions = fixed_dimensions

        # Isomorphic dimensions are only read for the two deep-learning task families.
        if self.basic_task_types_1 in ("computer vision", "natural language processing"):
            self.isomorphic_dimensions = connector_entry[5]

        self.simulated_data_file_name = simulated_data_file_name
        self.file_name = "_".join(
            [f"{self.component_choice}", f"{self.connector_choice}", f"{self.version_choice}"]
        )

    def run(self):
        """Dispatch to the check that matches ``self.component_choice``.

        Component names that match none of the entries are ignored silently.
        """
        # NOTE(review): the "loss_function" entry keeps the spelling
        # ``loss_funtion``; no method with either spelling is defined in the
        # part of the class visible here, so this entry raises AttributeError
        # unless it is provided elsewhere - confirm.
        handlers = (
            ("data_preparation", "data_preparation"),
            ("modeling", "modeling"),
            ("loss_function", "loss_funtion"),
            ("post_processing", "post_processing"),
            ("simulated_data", "simulated_data"),
        )
        for component_name, handler_name in handlers:
            if self.component_choice == component_name:
                # Resolve lazily so only the selected handler has to exist.
                getattr(self, handler_name)()
                break


    def tabular_data_test(self,data,dimensions,name):
        """Assert that ``data`` is a NaN-free numpy array of the given rank.

        Args:
            data: Object to validate.
            dimensions: Required number of array dimensions.
            name: Label used for ``data`` in the assertion messages.

        Raises:
            AssertionError: If ``data`` is not an ``np.ndarray``, has a
                different rank, or contains NaN values.
        """
        is_array = isinstance(data, np.ndarray)
        assert is_array, f"{name} is not a numpy array but {type(data)}"

        rank_matches = data.ndim == dimensions
        assert rank_matches, f"{name} is not a {dimensions}D numpy array"

        nan_count = np.isnan(data).sum()
        assert nan_count == 0, f"{name} contains NaNs"


    def retreive_configuration(self,filter_fixed_dimensions=False):
        """Placeholder: performs no work and returns ``None``."""
        return None


    def DL_data_test(self,tuple_of_tensors, input_or_output):
        """Check tensors against the shapes declared in the connector plan.

        The tensor layout is read from
        ``self.plans[self.connector_choice]["connector"][4]``.  Every declared
        dimension is resolved to a size through ``self.configuration``
        (directly, or through ``self.isomorphic_dimensions``), and the tensors
        of each requested group are asserted to be ``torch.Tensor`` objects
        with the declared rank and sizes.

        Args:
            tuple_of_tensors: Indexable collection of tensors.  When both
                groups are requested, output tensors are looked up after the
                input tensors; when a single group is requested, indexing
                starts at 0.
            input_or_output: List of group names to check, ``"input"`` and/or
                ``"output"``.

        Raises:
            AssertionError: If a tensor has the wrong type, rank or size.
            KeyError: If a dimension located in a requested group could not be
                resolved to a size.
        """
        connector = self.plans[self.connector_choice]["connector"]
        data_format = connector[4]

        # Name of every declared dimension: the shape entry itself when it is
        # fixed, otherwise the variable name attached to it.
        dimension_names = []
        for I_O in data_format:
            for tensor in data_format[I_O]:
                for dimension in tensor["shape"]:
                    if tensor[dimension]["fixed_or_variable"]=="fixed":
                        dimension_names.append(dimension)
                    else:
                        dimension_names.append(tensor[dimension]["variable_name"])

        # Resolve each name to a size: directly from the configuration, or via
        # the configured key of an isomorphic group that lists the name.
        dimension_values ={}
        for n in dimension_names:
            if n in self.configuration:
                dimension_values[n]=self.configuration[n]
            else:
                for isomorphic_group in self.isomorphic_dimensions :
                    for key in isomorphic_group :
                        if n in isomorphic_group[key] :
                            dimension_values[n]=self.configuration[key]

        # Group the resolved sizes as {group: {tensor_idx: {dimension_idx: {name: size}}}}.
        filtered_dimensions = {}
        for dimension_name in dimension_names:
            I_O, tensor_idx, dimension_idx = locate_dimension_in_connector(data_format, dimension_name)
            for n in input_or_output:
                if I_O==n:
                    resolved = {dimension_name: dimension_values[dimension_name]}
                    # setdefault creates the missing nesting levels on demand;
                    # no exception handling is needed to build the structure.
                    filtered_dimensions.setdefault(n, {}).setdefault(tensor_idx, {})[dimension_idx] = resolved

        for n in input_or_output:
            if n not in filtered_dimensions:
                continue
            target_tensor_group = data_format[n]
            for tensor_idx in filtered_dimensions[n]:
                # Outputs follow the inputs only when both groups are supplied.
                if n == "input" or len(input_or_output)==1:
                    global_tensor_idx = tensor_idx
                else:
                    global_tensor_idx = tensor_idx + len(data_format["input"])

                tensor = tuple_of_tensors[global_tensor_idx]
                # isinstance also accepts torch.Tensor subclasses.
                assert isinstance(tensor, torch.Tensor), f"The {tensor_idx+1}th tensor in {n} data is not a torch.tensor but {type(tensor)}"

                tensor_dimensions = filtered_dimensions[n][tensor_idx]

                if self.output_types=="integer representation of class labels":
                    # NOTE(review): with this output type only an output tensor
                    # found at global index 0 is checked; every other tensor
                    # (including all inputs) is skipped - confirm this is intended.
                    if n=="output" and global_tensor_idx==0:
                        declared_rank = len(target_tensor_group[tensor_idx]["shape"])
                        assert tensor.ndim == declared_rank-1 , f"The {tensor_idx}th tensor in {n} data has {tensor.ndim} dimensions but should have {declared_rank-1} dimensions"
                        # Relies on insertion order: the last recorded dimension
                        # is compared against the number of classes.
                        dimension_indices = list(tensor_dimensions.keys())
                        num_class_dimension = dimension_indices[-1]
                        other_dimensions = dimension_indices[:-1]
                        for dimension_idx in other_dimensions:
                            expected_size = list(tensor_dimensions[dimension_idx].values())[0]
                            assert tensor.shape[dimension_idx] == expected_size, f"The size of the {dimension_idx+1}th dimension of the {tensor_idx}th tensor in {n} data is {tensor.shape[dimension_idx]} but should be {expected_size}"
                        assert tensor.shape[num_class_dimension] == connector[6]["num_classes"], f"The size of the {num_class_dimension+1}th dimension of the {tensor_idx}th tensor in {n} data is {tensor.shape[num_class_dimension]} but should be {connector[6]['num_classes']}"
                else:
                    declared_rank = len(target_tensor_group[tensor_idx]["shape"])
                    assert tensor.ndim==declared_rank, f"The {tensor_idx}th tensor in {n} data has {tensor.ndim} dimensions but should have {declared_rank} dimensions"
                    for dimension_idx in tensor_dimensions:
                        expected_size = list(tensor_dimensions[dimension_idx].values())[0]
                        assert tensor.shape[dimension_idx] == expected_size, f"The size of the {dimension_idx+1}th dimension of the {tensor_idx}th tensor in {n} data is {tensor.shape[dimension_idx]} but should be {expected_size}"

    def variable_dimensions_arguments_string_generator(self):
        """Build the ``name=value,`` argument string for the variable dimensions.

        Every configured dimension that is not listed in
        ``self.fixed_dimensions`` is appended as ``name=value,`` (the trailing
        comma is kept) and is also bound as a module-level global.

        Returns:
            str: The concatenated keyword arguments, or ``""`` when there is
            nothing to pass.
        """
        # ``configure`` defaults both attributes to None; treat that as
        # "nothing configured" / "nothing fixed" instead of raising TypeError.
        configuration = self.configuration if self.configuration is not None else {}
        fixed_dimensions = self.fixed_dimensions if self.fixed_dimensions is not None else ()

        argument_parts = []
        for n in configuration:
            if n in fixed_dimensions:
                continue
            # SECURITY NOTE: this executes text built from the configuration;
            # the configuration must come from a trusted source.
            exec(f"{n}={configuration[n]}", globals())
            argument_parts.append(f"{n}={configuration[n]},")
        return "".join(argument_parts)

    def load_file(self):
        """Import and reload the component module, then bind it as a global.

        The module is looked up as ``self.file_name`` inside the package
        ``self.workspace`` and is bound in this module's globals under
        ``self.file_name`` so that later dynamically built calls can refer to
        it by name.

        Raises:
            ImportError: If the workspace or the component module cannot be
                imported.
        """
        # Same lookup order as ``from <workspace> import <file_name>``: use an
        # existing attribute of the package, otherwise import the submodule.
        package = importlib.import_module(self.workspace)
        if not hasattr(package, self.file_name):
            importlib.import_module(f"{self.workspace}.{self.file_name}")
        module = getattr(package, self.file_name)

        globals()[self.file_name] = module
        # Reload so that edits made since a previous import are picked up.
        importlib.reload(module)

    def simulated_data(self):
        """Check the simulated-data component of the configured choice.

        The component's ``generate_data`` is called with the variable
        dimensions as keyword arguments and the result is validated with
        ``DL_data_test`` for both the input and the output group.
        """
        self.load_file()
        call_arguments = self.variable_dimensions_arguments_string_generator()

        # The module is only known by name, so the call is built as text and
        # executed in a scratch namespace that is then published as globals.
        scratch_namespace = {}
        statement = f"simulated_data = {self.file_name}.generate_data({call_arguments})"
        exec(statement, globals(), scratch_namespace)
        globals().update(scratch_namespace)

        generated = scratch_namespace["simulated_data"]
        # A single returned object is wrapped so it can be indexed like a tuple.
        if type(generated) is not tuple:
            generated = (generated,)

        self.DL_data_test(generated, ["input", "output"])
        print("A unit test for simulated data is passed!")


    def simulated_data_generation(self,variable_dimensions_arguments_string=None):
        """Produce simulated inputs and outputs for the configured task.

        For "traditional algorithms" random numpy arrays are generated
        (100 samples, 5 features, 3 outputs/classes).  For "natural language
        processing" / "computer vision" the ``generate_data`` function of the
        simulated-data module is called and its result is split at the number
        of input tensors declared in the connector plan.

        Args:
            variable_dimensions_arguments_string: Keyword-argument text
                inserted into the ``generate_data`` call (deep-learning tasks
                only).

        Returns:
            tuple: ``(input, output)``; or ``(input, ground_truth_output,
            predicted_output)`` for deep-learning tasks whose output type is
            "integer representation of class labels".

        Raises:
            ValueError: If the task family, or the combination of task type
                and output type, is not supported.
        """
        integer_labels = "integer representation of class labels"
        probability_labels = "probability representation of class labels"

        if self.basic_task_types_1=="traditional algorithms":

            n_samples=100
            n_features=5
            n_output=3

            task = self.basic_task_types_2
            output_types = self.output_types

            # The input is drawn before the output, as a fixed order of random draws.
            simulated_input = np.random.randn(n_samples , n_features)
            simulated_output = None

            if task=="single-output regression":
                simulated_output=np.random.randn(n_samples)
            elif task=="multi-output regression":
                simulated_output = np.random.randn(n_samples,n_output)
            elif task=="binary classification":
                if output_types==integer_labels:
                    simulated_output = np.random.randint(2, size=n_samples)
                elif output_types==probability_labels:
                    simulated_output = np.random.rand(n_samples,2)
            elif task=="multi-class classification":
                if output_types==integer_labels:
                    simulated_output = np.random.randint(n_output, size=n_samples)
                elif output_types==probability_labels:
                    simulated_output = np.random.rand(n_samples,n_output)
            elif task=="multi-label classification":
                if output_types==integer_labels:
                    simulated_output = np.random.randint(2, size=(n_samples,n_output))
                # "probability representation of class labels" is accepted as
                # well, matching the multi-label case of ``data_preparation``.
                elif output_types in ("regression output", probability_labels):
                    simulated_output = np.random.rand(n_samples,n_output)

            if simulated_output is None:
                # Fail with an explicit message rather than an unbound-variable error.
                raise ValueError(
                    f"Unsupported combination of task type {task!r} and output type {output_types!r}")

            return simulated_input, simulated_output

        if self.basic_task_types_1=="natural language processing" or self.basic_task_types_1=="computer vision":
            # The generator module is only known by name at run time.
            exec(f"from {self.workspace} import {self.simulated_data_file_name}", globals())
            exec(f"importlib.reload({self.simulated_data_file_name})", globals())

            local_namespace = {}
            exec(
                f"simulated_data = {self.simulated_data_file_name}.generate_data({variable_dimensions_arguments_string})",
                globals(), local_namespace)
            globals().update(local_namespace)
            simulated_data = local_namespace["simulated_data"]

            tensor_group = self.plans[self.connector_choice]["connector"][4]
            input_tensor_count = len(tensor_group["input"])
            simulated_input = simulated_data[:input_tensor_count]

            if self.output_types==integer_labels:
                # One ground-truth tensor follows the inputs; the rest are predictions.
                simulated_ground_truth_output = simulated_data[input_tensor_count:input_tensor_count+1]
                simulated_predicted_output = simulated_data[input_tensor_count+1:]
                return simulated_input, simulated_ground_truth_output, simulated_predicted_output

            simulated_output = simulated_data[input_tensor_count:]
            return simulated_input, simulated_output

        raise ValueError(f"Unsupported task family {self.basic_task_types_1!r}")

    def modeling(self):
        """Check the modeling component of the configured choice.

        For "traditional algorithms" the component's ``generate_model`` is
        called with a random train/test split of simulated data.  For
        "natural language processing" / "computer vision" the component's
        ``generate_model`` must return a ``torch.nn.Module``; it is run once
        in eval mode on the simulated input and its output is validated with
        ``DL_data_test``.

        Raises:
            AssertionError: If the generated model is not a
                ``torch.nn.Module`` or its output fails the tensor checks.
        """
        if self.basic_task_types_1 == "traditional algorithms" :
            simulated_input , simulated_output = self.simulated_data_generation()
            X_train,  X_test, y_train, y_test = train_test_split(simulated_input, simulated_output, test_size=self.train_val_ratio)

            self.load_file()

            # The component module is only known by name, so the call is built
            # as text; the arrays are handed over through the local namespace.
            local_namespace = {"X_train": X_train, "y_train": y_train, "X_test": X_test}
            exec(
                f"prediction = {self.file_name}.generate_model(X_train, y_train, X_test)",
                globals(),local_namespace)
            globals().update(local_namespace)
            prediction=local_namespace["prediction"]

            if self.debug:
                print(prediction)

        if self.basic_task_types_1 == "natural language processing" or self.basic_task_types_1 == "computer vision":

            variable_dimensions_arguments_string=self.variable_dimensions_arguments_string_generator()

            if self.output_types=="integer representation of class labels":
                simulated_input, _,_=self.simulated_data_generation(variable_dimensions_arguments_string)
            else:
                simulated_input, _=self.simulated_data_generation(variable_dimensions_arguments_string)

            self.load_file()

            # Read the model back from an explicit namespace instead of relying
            # on a global created as a side effect; the global is still
            # published for code that looks it up there.
            local_namespace = {}
            exec(
                f"model_raw = {self.file_name}.generate_model({variable_dimensions_arguments_string})",
                globals(), local_namespace)
            globals().update(local_namespace)
            model_raw = local_namespace["model_raw"]

            # ``torch`` is the module-level import; no function-local import is
            # used, so the name is never shadowed inside this method.
            assert isinstance(model_raw, torch.nn.Module), f"model_raw is an instance of {type(model_raw)} but should be an instance of (child class of) torch.nn.Module"

            model = model_raw.eval()

            with torch.no_grad():
                output = model.forward(*simulated_input)

            # DL_data_test indexes its first argument, so wrap the single result.
            batch_outputs = [output]

            if self.debug:
                print(batch_outputs)
                print(type(batch_outputs))

            self.DL_data_test(batch_outputs, ["output"])

        print("A unit test for modeling is passed!")

    def data_preparation(self):
        """Check the data-preparation component of the configured choice.

        For "traditional algorithms" the component's ``process_data`` must
        return ``X_train, X_test, y_train, y_test`` as NaN-free numpy arrays
        of the expected rank.  For "natural language processing" /
        "computer vision" the component's ``generate_dataloader`` must return
        a train and a validation ``torch.utils.data.DataLoader`` whose first
        batches pass ``DL_data_test``.  In both cases the held-out fraction is
        compared with ``self.train_val_ratio`` using ``math.isclose`` with
        ``rel_tol=self.tolerance``.

        Raises:
            AssertionError: If any of the checks fails.
        """
        if self.basic_task_types_1 == "traditional algorithms":
            self.load_file()

            # Run the dynamically built call in a scratch namespace and read
            # the results back explicitly; they are still published as module
            # globals for code that looks them up there.
            local_namespace = {}
            exec(
                f"X_train, X_test, y_train, y_test= {self.file_name}.process_data()",
                globals(), local_namespace)
            globals().update(local_namespace)
            X_train = local_namespace["X_train"]
            X_test = local_namespace["X_test"]
            y_train = local_namespace["y_train"]
            y_test = local_namespace["y_test"]

            self.tabular_data_test(X_train, 2, "X_train")
            self.tabular_data_test(X_test, 2, "X_test")

            # Expected rank of the targets; combinations that are not listed
            # get no target check.
            integer_labels = "integer representation of class labels"
            probability_labels = "probability representation of class labels"
            if self.basic_task_types_2 == "single-output regression" :
                target_ndim = 1
            elif self.basic_task_types_2 == "multi-output regression" :
                target_ndim = 2
            else:
                target_ndim = {
                    ("binary classification", integer_labels): 1,
                    ("binary classification", probability_labels): 2,
                    ("multi-class classification", integer_labels): 1,
                    ("multi-class classification", probability_labels): 2,
                    ("multi-label classification", integer_labels): 2,
                    ("multi-label classification", probability_labels): 2,
                }.get((self.basic_task_types_2, self.output_types))

            if target_ndim is not None:
                self.tabular_data_test(y_train, target_ndim, "y_train")
                self.tabular_data_test(y_test, target_ndim, "y_test")

            assert isclose(len(y_test) / (len(y_train) + len(y_test)),self.train_val_ratio,rel_tol=self.tolerance), f"y_test has {len(y_test) / (len(y_train) + len(y_test))*100}% of the data but should have {self.train_val_ratio*100}% of the data"

        if self.basic_task_types_1 == "natural language processing" or self.basic_task_types_1 == "computer vision":
            variable_dimensions_arguments_string=self.variable_dimensions_arguments_string_generator()
            self.load_file()

            local_namespace = {}
            exec(
                f"train_loader, val_loader  = {self.file_name}.generate_dataloader({variable_dimensions_arguments_string})",
                globals(), local_namespace)
            globals().update(local_namespace)
            train_loader = local_namespace["train_loader"]
            val_loader = local_namespace["val_loader"]

            # The messages report the actual type after the expected one.
            assert isinstance(train_loader, torch.utils.data.DataLoader), f"train_loader is not a torch.utils.data.DataLoader but a {type(train_loader)}"
            assert isinstance(val_loader, torch.utils.data.DataLoader), f"val_loader is not a torch.utils.data.DataLoader but a {type(val_loader)}"

            assert isclose(len(val_loader.dataset)/(len(train_loader.dataset) + len(val_loader.dataset)),self.train_val_ratio,rel_tol=self.tolerance), f"val_loader has {len(val_loader.dataset)/(len(train_loader.dataset) + len(val_loader.dataset))*100}% of the data but should have {self.train_val_ratio*100}% of the data"

            # Only the first batch of each loader is inspected.
            for loader in (train_loader, val_loader):
                for tensor_group in loader:
                    self.DL_data_test(tensor_group, ["input","output"])
                    break

        print("A unit test for data preparation is passed!")

    def post_processing(self):
        """Check the post-processing component of the configured choice.

        The component's ``generate_evaluation`` is called with simulated
        ground-truth and predicted outputs, and the returned score must be a
        ``float``.  For "traditional algorithms" the call is
        ``generate_evaluation(ground_truth_output, predicted_output)``; for
        "natural language processing" / "computer vision" all predicted
        tensors are passed first, followed by all actual tensors.

        Raises:
            AssertionError: If the returned score is not a ``float``.
            ValueError: If the task family is not supported.
        """
        if self.basic_task_types_1 == "traditional algorithms":
            _, predicted_output=self.simulated_data_generation()
            _, ground_truth_output=self.simulated_data_generation()

            self.load_file()

            local_namespace = {
                "predicted_output": predicted_output,
                "ground_truth_output": ground_truth_output,
            }
            exec(
                f"score = {self.file_name}.generate_evaluation(ground_truth_output, predicted_output)",
                globals(),local_namespace)

        elif self.basic_task_types_1 == "natural language processing" or self.basic_task_types_1 == "computer vision":
            variable_dimensions_arguments_string=self.variable_dimensions_arguments_string_generator()

            if self.output_types=="integer representation of class labels":
                _,simulated_actual_output,simulated_predicted_output=self.simulated_data_generation(variable_dimensions_arguments_string)
            else:
                _, simulated_predicted_output=self.simulated_data_generation(variable_dimensions_arguments_string)
                _, simulated_actual_output=self.simulated_data_generation(variable_dimensions_arguments_string)

            self.load_file()

            # Bind every tensor under a generated name in the namespace the
            # call is executed in; plain dict writes are enough for that.
            local_namespace = {}
            argument_names = []
            for tensor_idx, tensor in enumerate(simulated_predicted_output):
                argument_name = f"predicted_tensor_{tensor_idx}"
                local_namespace[argument_name] = tensor
                argument_names.append(argument_name)
            for tensor_idx, tensor in enumerate(simulated_actual_output):
                argument_name = f"actual_tensor_{tensor_idx}"
                local_namespace[argument_name] = tensor
                argument_names.append(argument_name)
            # Each name keeps a trailing comma, which is valid in a call.
            evaluation_arguments_string = "".join(f"{name}," for name in argument_names)

            exec(
                f"score = {self.file_name}.generate_evaluation({evaluation_arguments_string})",
                globals(),local_namespace)

        else:
            # Fail with an explicit message rather than an unbound-variable error.
            raise ValueError(f"Unsupported task family {self.basic_task_types_1!r}")

        # Publish the namespace as module globals for code that reads it there.
        globals().update(local_namespace)

        score=local_namespace["score"]
        assert isinstance(score, float), f"The returned score is an instance of {type(score)} but should be an instance of float"

        if self.debug:
            print("the evaluation score is",score)

        print("A unit test for post processing is passed!")