import pandas as pd
from collections import defaultdict


def _normalize_params(params):
    """Build a canonical, order-independent string key for a parameter dict.

    Every key is forced to carry the ``param-`` prefix, and a value that is a
    list holding exactly one element is replaced by that element.  The
    resulting ``(key,value)`` pairs are rendered as text, sorted, and joined
    with ``+``.  When both ``x`` and ``param-x`` are present, the one seen
    later in iteration order wins.
    """
    def _unwrap(value):
        # A one-element list is treated as its single element.
        if type(value) is list and len(value) == 1:
            return value[0]
        return value

    prefixed = {
        (name if name.startswith('param-') else 'param-' + name): _unwrap(value)
        for name, value in params.items()
    }
    rendered = sorted(f"({name},{value})" for name, value in prefixed.items())
    return '+'.join(rendered)


def _normalize_param_names(params):
    """Return the keys of *params*, sorted and joined by commas.

    Only the names matter, so dicts with identical keys but different values
    produce the same signature string.
    """
    names = list(params)
    names.sort()
    return ','.join(names)


class APIDataLoader:
    """Load recorded API calls and index their outputs by API name and parameters.

    After construction, ``self.api_result`` maps an API name to a dict whose
    keys are ``_normalize_params`` strings of a call's input and whose values
    are the recorded outputs.

    ``apis_map`` is read as::

        {api_key: {'ori_name': str,
                   'param_mapping': {'input': {name: ori_name, ...},
                                     'output': {name: ori_name, ...}}}}

    NOTE(review): the direction of ``param_mapping`` (logged/renamed name ->
    original name) is inferred from how ``_add_raw_data`` uses it; confirm
    against whatever builds ``apis_map``.
    """

    def __init__(self, apis_map, filepath, extra_filepath='/your_file'):
        """Load *filepath* (and *extra_filepath*) and build ``self.api_result``.

        Args:
            apis_map: per-API metadata (see class docstring).
            filepath: primary call log, ``.xlsx`` or ``.json``
                (see ``_load_raw_data``).
            extra_filepath: Excel workbook with a ``call_chains`` column whose
                calls are appended to the raw data.  Defaults to the
                previously hard-coded placeholder ``'/your_file'``; pass
                ``None`` to skip it.
        """
        self.apis_map = apis_map
        self.api_result = defaultdict(dict)
        self.raw_data = self._load_raw_data(filepath)
        if extra_filepath is not None:
            self._add_raw_data(extra_filepath)
        api_name_map = self._build_name_mapping()
        self._build_api_result(api_name_map)

    def _load_raw_data(self, filepath):
        """Read the primary call log into a flat list of records.

        Records are expected to look like ``[api_name, [input, output]]``
        (that is how ``_build_api_result`` indexes them).

        * ``...xlsx``: the ``call_apis`` column holds, per row, a list of
          records; all rows are flattened into one list.
        * ``...json``: one record literal per line; blank lines are skipped.

        Raises:
            ValueError: if *filepath* ends in neither ``xlsx`` nor ``json``.
        """
        if filepath.endswith('xlsx'):
            df = pd.read_excel(filepath, index_col=0).reset_index(drop=True)
            # SECURITY: eval runs arbitrary code stored in the spreadsheet; only
            # load trusted files (ast.literal_eval is safer if cells hold
            # plain Python literals).
            df['call_apis'] = df['call_apis'].apply(eval)
            return [call for row in df['call_apis'].tolist() for call in row]

        if filepath.endswith('json'):
            with open(filepath, 'r', encoding='utf-8') as f:
                # SECURITY: same eval caveat as above.  Blank lines (e.g. a
                # trailing newline) are skipped because eval('') raises
                # SyntaxError.
                return [eval(line) for line in f if line.strip()]

        raise ValueError(f"Unsupported data file (expected .xlsx or .json): {filepath!r}")

    @staticmethod
    def _to_original_names(params, name_map):
        """Rename the keys of *params* through *name_map*.

        Keys are looked up with a leading ``param-`` removed (only the prefix,
        not inner occurrences).  Raises KeyError for a key that *name_map*
        does not contain.
        """
        renamed = {}
        for key, value in params.items():
            bare = key[len('param-'):] if key.startswith('param-') else key
            renamed[name_map[bare]] = value
        return renamed

    def _add_raw_data(self, file):
        """Append the calls from a ``call_chains`` workbook to ``self.raw_data``.

        Each logged call contributes up to two records: one translated to the
        ``ori_name`` API and original parameter names via ``apis_map`` (if the
        translation fails, the error is printed and that record is skipped),
        and one with the call exactly as logged.
        """
        df = pd.read_excel(file, index_col=0).reset_index(drop=True)
        # SECURITY: eval runs arbitrary code stored in the workbook; only load
        # trusted files.
        df['call_chains'] = df['call_chains'].apply(eval)
        res = []

        for chain in df['call_chains']:
            for call_api in chain:
                logged_name = call_api['api_name']
                api_ = logged_name if logged_name.startswith('api-') else 'api-' + logged_name
                try:
                    api_info = self.apis_map[api_]
                    param_mapping = api_info['param_mapping']
                    new_input = self._to_original_names(call_api['input'], param_mapping['input'])
                    new_output = self._to_original_names(call_api['output']['data'], param_mapping['output'])
                    res.append([api_info['ori_name'], [new_input, new_output]])
                except Exception as e:
                    # Deliberate best effort: unmapped APIs/params are reported
                    # and skipped, but the raw record below is still kept.
                    print(api_, e)
                res.append([logged_name, [call_api['input'], call_api['output']['data']]])
        self.raw_data.extend(res)

    def _build_name_mapping(self):
        """Invert every API's parameter mapping and group APIs by original name.

        Returns ``{ori_name: {input_signature: {api_key: mapping}}}`` where
        ``mapping`` is ``{'name': api_key, 'input_param_map': ...,
        'output_param_map': ...}``.  Each param map is the inverse of the
        corresponding ``param_mapping`` dict (value -> key), and
        ``input_signature`` is ``_normalize_param_names`` of the inverted
        input map.
        """
        api_name_map = defaultdict(dict)
        for api_key, api_info in self.apis_map.items():
            param_mapping = api_info['param_mapping']
            input_param_map = {value: key for key, value in param_mapping['input'].items()}
            # Bug fix: this used to invert param_mapping['input'] a second time.
            # A missing 'output' mapping leaves output keys untranslated.
            output_param_map = {value: key for key, value in param_mapping.get('output', {}).items()}
            mapping = {'name': api_key,
                       'input_param_map': input_param_map,
                       'output_param_map': output_param_map}
            signature = _normalize_param_names(input_param_map)
            api_name_map[api_info['ori_name']].setdefault(signature, {})[api_key] = mapping
        return api_name_map

    @staticmethod
    def _to_renamed_params(params, name_map):
        """Return *params* with keys found in *name_map* renamed, all prefixed 'param-'.

        Keys absent from *name_map* keep their name (still prefixed).
        """
        return {'param-' + name_map.get(key, key): value for key, value in params.items()}

    def _build_api_result(self, api_name_map):
        """Populate ``self.api_result`` from ``self.raw_data``.

        A record whose API name has no entry in *api_name_map* is stored
        unchanged under its own name.  Otherwise it is re-keyed for every
        renamed API with exactly the record's input signature, or for all of
        that API's signatures when none matches.  Input and output keys are
        translated through the inverted maps and prefixed with ``param-``.
        """
        for record in self.raw_data:
            api_name = record[0]
            ori_input, ori_output = record[1][0], record[1][1]

            signatures = api_name_map.get(api_name)
            if not signatures:
                self.api_result[api_name][_normalize_params(ori_input)] = ori_output
                continue

            signature = _normalize_param_names(ori_input)
            candidate_signatures = [signature] if signature in signatures else list(signatures)

            # Normalize once: a single dict (or bare string) output is treated
            # as a one-item list instead of being iterated key-by-key or
            # char-by-char.
            outputs = [ori_output] if isinstance(ori_output, (dict, str)) else ori_output

            for sig in candidate_signatures:
                for mapping in signatures[sig].values():
                    new_input = self._to_renamed_params(ori_input, mapping['input_param_map'])
                    # String items carry no named fields, so they become {}.
                    new_output = [
                        {} if isinstance(item, str)
                        else self._to_renamed_params(item, mapping['output_param_map'])
                        for item in outputs
                    ]
                    self.api_result[mapping['name']][_normalize_params(new_input)] = new_output


# ---------------------------
# simulator/base.py
# ---------------------------
class BaseAPISimulator:
    """Interface shared by API simulators.

    Subclasses override :meth:`execute`; the base implementation only raises.
    """

    def execute(self, api_name: str, params: dict) -> dict:
        """Simulate calling *api_name* with *params* and return a result dict.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


# ---------------------------
# simulator/config_based.py
# ---------------------------
class ConfigAPISimulator(BaseAPISimulator):
    """Answer API calls from a pre-built table of recorded results.

    ``api_result`` maps an API name to ``{params_key: recorded_output}`` where
    ``params_key`` is the string produced by ``_normalize_params``.
    """

    def __init__(self, api_result):
        self.api_result = api_result

    def execute(self, api_name: str, params: dict) -> dict:
        """Return the recorded result for *api_name* called with *params*.

        * Unknown API: error result (``"API not found"``).
        * Known API with an exact parameter match: success result carrying
          the recorded output.
        * Known API without a match: mock result from ``_fallback_result``.
        """
        # Check the API first so an unknown name never pays for normalization.
        if api_name not in self.api_result:
            return self._build_error_result()

        recorded = self.api_result[api_name]
        param_key = _normalize_params(params)
        if param_key in recorded:
            return self._build_success_result(recorded[param_key])
        return self._fallback_result(api_name)

    def _fallback_result(self, api_name):
        """Mock response for a known API called with unrecorded parameters.

        Returns the first recorded output for the API (dict iteration order),
        or the generic ``'call success'`` message when nothing was recorded.
        """
        recorded = self.api_result[api_name]
        if recorded:
            # next(iter(...)) takes the first value without copying them all.
            return {"status": "success", "message": next(iter(recorded.values())), "type": "mock_data"}

        return {"status": "success", "message": 'call success', "type": "mock_data"}

    @staticmethod
    def _build_success_result(data):
        """Wrap a recorded output in a success envelope."""
        return {"status": "success", "message": data, "type": "success"}

    @staticmethod
    def _build_error_result():
        """Envelope returned when the API name is unknown."""
        return {"status": "error", "message": "API not found", "type": "fail"}


class APISimulator:
    """Facade that loads recorded calls and serves them via ConfigAPISimulator."""

    def __init__(self, apis, file_):
        """Build the lookup table from *file_* using the API metadata *apis*."""
        self.simulator = ConfigAPISimulator(APIDataLoader(apis, file_).api_result)

    def run(self, api_name, params):
        """Forward the call to the underlying simulator's ``execute``."""
        return self.simulator.execute(api_name=api_name, params=params)


if __name__ == "__main__":
    from function_call_agent.tools import Graph_Search

    # Demo: build a simulator from the graph-search API metadata and replay one
    # call.  NOTE: '/Users/' ends in neither 'xlsx' nor 'json', so the data
    # loader raises ValueError; point it at a real call log before running.
    searcher = Graph_Search(search_type='alpha_beta', graph_degree=None)
    simulator = APISimulator(apis=searcher.apis, file_='/Users/')
    demo_params = {
        "param-exercise_id_3": "1",
        'param-patient_id_10': 123,
        'param-reps_1': '10',
        "param-rest_time_1": "15",
        'param-sets_completed_2': 3,
    }
    print(simulator.run("api-add_exercise_to_program_1", demo_params))


