from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import ResponseSchema
from langchain.output_parsers import StructuredOutputParser

import os
import pandas as pd
import json
import re
from pandas.api.types import is_integer_dtype, is_float_dtype, is_string_dtype

# Supported dataset names mapped to a short description of what is predicted
# and from which kind of features. Within this file the keys are used by
# langchain_templates_weights to check that a dataset is supported.
knowledge = {
    "adult": "salary above 50K based on demographic features",
    "bank": "term deposit subscription based on client features",
    "blood": "donation based on donor features",
    "car": "car acceptability based on car features",
    "communities": "crime rate based on community features",
    "credit-g": "credit approval based on client features",
    "diabetes": "diabetes based on patient features",
    "heart": "heart disease based on patient features",
    "myocardial": "myocardial infarction based on patient features",
    "california_housing": "house price based on house features",
    "NHANES": "health status based on patient features",
    "cultivars": "soybean cultivar based on plant features",
}
def parser_weights_answer(answer, info, cat_categories):
    """Extract and validate per-feature importance weights from an LLM answer.

    The first non-nested ``{...}`` span found in ``answer`` is parsed as JSON.
    It must map every column named in ``info['N_cols']`` and
    ``info['C_cols']`` to a number between 0 and 1 (inclusive).

    Args:
        answer: Raw reply text that contains a JSON object.
        info: Mapping providing the ``'N_cols'`` and ``'C_cols'`` lists of
            column names to validate.
        cat_categories: Unused; kept so existing callers keep working.

    Returns:
        The parsed dict. The weight of every validated column is converted to
        ``float``; any other keys are returned exactly as parsed.

    Raises:
        IndexError: ``answer`` contains no ``{...}`` span.
        json.JSONDecodeError: The first ``{...}`` span is not valid JSON.
        KeyError: A required column is missing from the parsed object.
        AssertionError: A weight is not a number or lies outside ``[0, 1]``.
    """
    # Only the first brace group is ever used, so only that one is parsed;
    # unrelated "{...}" text later in the answer must not break parsing.
    match = re.search(r"\{[^{}]*\}", answer)
    if match is None:
        # Same exception type the original indexing raised, with a clear message.
        raise IndexError("no JSON object found in answer")
    weights = json.loads(match.group(0))

    print(weights)

    for col in info['N_cols'] + info['C_cols']:
        value = weights[col]
        # JSON writes whole numbers such as 0 or 1 without a decimal point and
        # json.loads returns them as int, so ints are valid weights. bool is a
        # subclass of int and is rejected explicitly.
        # AssertionError is raised explicitly (rather than via `assert`) so
        # the checks keep their exception type but still run under `python -O`.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise AssertionError(f"weight of {col} should be float")
        if not 0 <= value <= 1:
            raise AssertionError(f"weight of {col} should be between 0 and 1")
        weights[col] = float(value)
    return weights

def langchain_templates_weights(info, dataset, cat_condidates):
    """Build the chat prompt that asks for a per-feature importance weight.

    Args:
        info: Mapping read for ``'N_cols'`` and ``'C_cols'`` (column name
            lists), ``'explantation'`` (column name -> description) and
            ``'task'`` (task description text).
        dataset: Dataset name; must be a key of the module-level ``knowledge``.
        cat_condidates: Not used by this function.

    Returns:
        A 4-tuple ``(prompt, format_instructions, task_desc, feature_desc)``:
        the ``ChatPromptTemplate``, the structured-output format instructions,
        the task description, and one ``"<column>: <description>"`` line per
        column.

    Raises:
        AssertionError: ``dataset`` is not a key of ``knowledge``.
        KeyError: ``info`` lacks a required key, or a column has no entry in
            ``info['explantation']``.
    """
    assert dataset in knowledge, f"dataset {dataset} not supported"

    numeric_cols = info['N_cols']
    categorical_cols = info['C_cols']
    descriptions = info['explantation']
    print(numeric_cols)
    print(categorical_cols)

    task_desc = info['task']
    print(task_desc)

    all_cols = numeric_cols + categorical_cols

    # One "<column>: <description>" line per column, numeric columns first.
    feature_desc = "".join(
        f"{column}: {descriptions[column]}\n" for column in all_cols
    )
    print(feature_desc)

    # Every column becomes a float field in the structured output; the
    # per-field description is left empty.
    schemas = [
        ResponseSchema(type="float", name=column, description="")
        for column in all_cols
    ]
    format_instructions = StructuredOutputParser.from_response_schemas(
        schemas
    ).get_format_instructions()

    # The doubled braces collapse to single braces in the f-string, leaving
    # {task_desc}, {feature_desc} and {format_instructions} as template
    # variables for ChatPromptTemplate.
    template_text = f"""You are an expert in analyzing relationships between features and target variables.
I will provide you with the task description and feature descriptions of dataset. Your goal is to analyze the importance of each feature in predicting the target variable based on the relationship of features and target.

Task:{{task_desc}}

Feature: {{feature_desc}}

{{format_instructions}}

Please provide the importance of each feature in predicting the target variable. The importance of each feature should be a float value between 0 and 1.
"""
    prompt = ChatPromptTemplate.from_template(template=template_text)

    return prompt, format_instructions, task_desc, feature_desc