from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser

import os
import pandas as pd
import json
import re
from pandas.api.types import is_integer_dtype, is_float_dtype, is_string_dtype

# Maps each supported dataset name to a short phrase describing its
# prediction task ("<target> based on <kind of> features").
# In the code visible in this file only the keys are consulted:
# langchain_templates_oracle_feature() rejects any dataset name that is not
# a key of this dict.
knowledge = {
    "adult": "salary above 50K based on demographic features",
    "bank": "term deposit subscription based on client features",
    "blood": "donation based on donor features",
    "car": "car acceptability based on car features",
    "communities": "crime rate based on community features",
    "credit-g": "credit approval based on client features",
    "diabetes": "diabetes based on patient features",
    "heart": "heart disease based on patient features",
    "myocardial": "myocardial infarction based on patient features",
    "california_housing": "house price based on house features",
    "NHANES": "health status based on patient features",
    "cultivars": "soybean cultivar based on plant features",
}
def parser_answer(answer, info, cat_categories, feature_name, num_cnt):
    """Extract and validate the per-target-class value dictionary in *answer*.

    The first flat (non-nested) ``{...}`` span found in *answer* is parsed as
    JSON and then checked against the dataset metadata.

    Args:
        answer: Response text (presumably an LLM reply) expected to contain a
            flat JSON object mapping each target class to a list of values.
        info: Dataset metadata; the keys ``'N_cols'``, ``'C_cols'`` and
            ``'target'`` are read.
        cat_categories: Mapping from column name to its categories. It is
            indexed with the target column and, when ``feature_name`` is in
            ``info['C_cols']``, with ``feature_name``.
        feature_name: Name of the feature whose values are being validated.
        num_cnt: Exact number of values required per target class when
            ``feature_name`` is in ``info['N_cols']``.

    Returns:
        The parsed dictionary ``{target_class: [values, ...]}``.

    Raises:
        IndexError: If *answer* contains no ``{...}`` span.
        json.JSONDecodeError: If the first ``{...}`` span is not valid JSON.
        AssertionError: If the parsed dictionary fails any validation check.
    """
    # Only the first flat {...} span is ever used, so only that one is parsed;
    # stray brace-delimited text later in the answer cannot break parsing.
    match = re.search(r"\{[^{}]*\}", answer)
    if match is None:
        raise IndexError("no dict-like string found in the answer")
    answer_dict = json.loads(match.group(0))

    N_cols = info['N_cols']
    C_cols = info['C_cols']
    target = info['target']
    target_classes = cat_categories[target]

    # Explicit raises (instead of ``assert``) keep the validation active even
    # when Python runs with -O, while preserving the AssertionError type.
    if set(answer_dict) != set(target_classes):
        raise AssertionError("Missing target classes in the response")

    is_numeric = feature_name in N_cols
    is_categorical = feature_name in C_cols
    allowed = set(cat_categories[feature_name]) if is_categorical else None

    for target_class in target_classes:
        feature_values = answer_dict[target_class]

        # Type check first: len() on a scalar would raise TypeError and hide
        # the real problem behind an unrelated exception.
        if not isinstance(feature_values, list):
            raise AssertionError(
                f"Feature values for target class {target_class} should be a list"
            )
        if len(feature_values) == 0:
            raise AssertionError(
                f"Feature values for target class {target_class} should not be empty"
            )

        if is_numeric:
            if len(feature_values) != num_cnt:
                raise AssertionError(
                    f"Feature values for target class {target_class} should contain {num_cnt} values"
                )
            # bool is a subclass of int; JSON true/false are not numeric values.
            if not all(
                isinstance(x, (int, float)) and not isinstance(x, bool)
                for x in feature_values
            ):
                raise AssertionError(
                    f"Feature values for target class {target_class} should be integers or floats"
                )

        if is_categorical:
            try:
                valid = set(feature_values).issubset(allowed)
            except TypeError:
                # Unhashable entries (e.g. nested lists) can never be categories.
                valid = False
            if not valid:
                raise AssertionError(
                    f"Feature values for target class {target_class} has invalid categories: {feature_values}"
                )
    return answer_dict

def langchain_templates_oracle_feature(df_empty,info, dataset, cat_condidates,feature_name, num_cnt):
    """Build the prompt pieces asking for typical values of one feature per target class.

    Args:
        df_empty: DataFrame containing a column ``feature_name``; only that
            column's dtype is read.
        info: Dataset metadata; the keys ``'target'``, ``'task'``,
            ``'N_cols'``, ``'C_cols'`` and ``'explantation'`` (sic) are read.
        dataset: Dataset name; must be a key of the module-level ``knowledge``.
        cat_condidates: Mapping from column name to a sequence of categories.
            It is indexed with the target column and, for a feature in
            ``info['C_cols']``, with ``feature_name``.
        feature_name: Name of the feature the prompt is about.
        num_cnt: Number of typical values requested per target class for a
            feature in ``info['N_cols']``.

    Returns:
        A tuple ``(prompt, task, feature_desc, format_instructions)`` where
        ``prompt`` is a ``ChatPromptTemplate`` with the placeholders
        ``task``, ``feature_desc`` and ``format_instructions``.

    Raises:
        AssertionError: If ``dataset`` is not a key of ``knowledge``.
        SystemExit: If the feature column has an unsupported dtype, or the
            feature is in neither ``info['C_cols']`` nor ``info['N_cols']``.
    """
    # Explicit raise keeps the check active under ``python -O``.
    if dataset not in knowledge:
        raise AssertionError(f"dataset {dataset} not supported")

    target = info['target']
    task = info['task']
    N_cols = info['N_cols']
    C_cols = info['C_cols']
    explantation = info['explantation']

    feature_dtype = df_empty[feature_name].dtype
    if is_integer_dtype(feature_dtype):
        col_type = "integers"
    elif is_float_dtype(feature_dtype):
        col_type = "floats"
    elif is_string_dtype(feature_dtype):
        col_type = "strings"
    else:
        # ``exit()`` is an interactive-shell helper and would exit with
        # status 0; raising SystemExit with a message reports the failure.
        raise SystemExit(f"column {feature_name} has unknown type")

    # One response field per target class.
    response_schemas = [
        ResponseSchema(
            type=f"[list of {col_type}]",
            name=target_class,
            description=f"possible values of feature {feature_name} for target class {target_class}"
        )
        for target_class in cat_condidates[target]
    ]
    output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
    format_instructions = output_parser.get_format_instructions()
    feature_desc = f"{feature_name}, {explantation[feature_name]}"

    if feature_name in C_cols:
        categories = cat_condidates[feature_name]
        if len(categories) <= 20:
            feature_desc += f"(categorical variable with categories: {categories})"
        else:
            # Long category lists are abbreviated to first, second and last.
            feature_desc += f"(categorical variable with categories: [{categories[0]}, {categories[1]}, ..., {categories[-1]}])"
    elif feature_name in N_cols:
        feature_desc += "(numeric variable, you should use your prior knowledge to determine the appropriate ranges of values)"
    else:
        raise SystemExit(f"column {feature_name} not found in the dataset")

    # Values interpolated directly into the template text must have their
    # braces escaped, otherwise the prompt template would read them as
    # placeholders.
    safe_feature_name = _escape_braces(feature_name)

    generator_template = f"""You are an expert in analyzing relationships between features and target variables.

Given a feature description and a task, your goal is to analyze how the feature relates to the target and then generate a dictionary with specific details.

Task:{{task}}

Feature: {{feature_desc}}

First conduct a thorough analysis of the relationship between the feature and task using your prior knowledge.
Then based on this analysis, create a dictionary with the following format:
{{format_instructions}}
The {safe_feature_name} values should be presented as {col_type} in lists.
"""
    if feature_name in N_cols:
        safe_classes = _escape_braces(cat_condidates[target])
        generator_template += f'Make sure to include {num_cnt} typical {safe_feature_name} values for each target class in {safe_classes}.'
    elif feature_name in C_cols:
        generator_template += f'For the kind of {safe_feature_name} values that are hard to predict, it is not necessary to include them in the dictionary. But make sure list of each target class is not empty.'
    prompt = ChatPromptTemplate.from_template(template=generator_template)

    return prompt, task, feature_desc, format_instructions


def _escape_braces(value):
    """Return ``str(value)`` with curly braces doubled so a format-style template treats them literally."""
    return str(value).replace("{", "{{").replace("}", "}}")