from typing import ClassVar
import streamlit as st
from stqdm import stqdm
from dsl import spec, SpecTracker
from dsl.grammar import (
    create_variable as V,
    Expectation as E,
    RETURN_VARIABLE as r
)
from dsl.grammar import Specification
from dsl.visualization import add_vega_chart, create_viz_spec
from dsl.tests.datasets import get_dataset_names, get_dataset_attr_list, get_dataset_attr_dict, Dataset, get_dataset
from dsl.tests.models import get_models, get_model, ModelBasedTest, InvalidTargetError, InvalidTrainError, _view_maintenance
import pdb

class CustomDatasetModel(ModelBasedTest):
    """Model-based test over a user-chosen dataset, target and input columns.

    Builds a ``Dataset`` restricted to the chosen variables, instantiates
    ``model`` on it, and exposes the test split as per-row dicts via ``data``.
    """

    def __init__(self, dataset, output_var, input_vars, model):
        """Build the restricted dataset and instantiate the model on it.

        Args:
            dataset: Source dataset; only its ``train`` and ``test`` splits are read.
            output_var: Name of the target column.
            input_vars: Names of the input columns (also used as the attribute list).
            model: Callable (e.g. a model class) invoked with the new ``Dataset``.
        """
        # NOTE(review): ModelBasedTest.__init__ is not called (unchanged from the
        # original) — confirm the base class needs no initialisation.
        self.output_var = output_var
        self.input_vars = input_vars
        # NOTE(review): the attribute list is the input columns only; the target is
        # supplied separately via ``target=`` — confirm Dataset expects this.
        self.dataset: Dataset = Dataset(
            dataset.train,
            dataset.test,
            self.input_vars,
            target=output_var,
            inputs=input_vars,
        )

        self.model = model(self.dataset)

    def update_spec(self, new_spec):
        """Attach ``new_spec`` by wrapping the model's prediction in a spec-checked function.

        The wrapped function is stored on ``self.f_x``.

        Args:
            new_spec: Parsed ``Specification`` to attach.

        Raises:
            ValueError: if ``new_spec`` is not a ``Specification``.
        """
        # TypeError would be more conventional, but ValueError is kept because
        # callers may already rely on it.
        if not isinstance(new_spec, Specification):
            raise ValueError("Input provided was not parsed into a specification")

        @spec(new_spec, include_confidence=True)
        def decision_func(**kwargs):
            # Presumably called with the row dicts built by ``data``; only the
            # ``x`` entry (one row of input values) is used here.
            x = kwargs.get("x")
            # TODO: confirm whether this should be self.model.model.predict instead.
            prediction = self.model.predict(x.reshape(1, -1))
            # Single-row batch: unwrap the lone prediction.
            return prediction[0]

        self.f_x = decision_func

    @property
    def data(self):
        """Return the test split as a list of per-row dicts.

        Each dict maps ``"x"`` to the row's input values as a NumPy array (in
        ``self.input_vars`` order), plus one entry per key of
        ``self.dataset.attributes_dict`` holding that row's value.
        """
        # Hoisted out of the row loop; the keys do not change between rows.
        attr_names = list(self.dataset.attributes_dict.keys())
        test_data = []
        for _, row in self.dataset.test.iterrows():
            row_dict = {"x": row.loc[self.input_vars].to_numpy()}
            row_dict.update((attr, row[attr]) for attr in attr_names)
            test_data.append(row_dict)
        return test_data


def provide_model_eval_interface(model_obj, spec_code=None):
    """Attach the user's spec to ``model_obj`` and render the analysis controls.

    Args:
        model_obj: Object exposing ``update_spec``, ``run_eval_loop`` and
            ``get_tabular_rep`` (e.g. ``CustomDatasetModel``).
        spec_code: Text of the spec expression. When omitted (the original
            call signature), the module-level ``spec_code`` that the
            ``__main__`` section reads from the text box is used.

    Raises:
        NameError: if ``spec_code`` is omitted and no module-level
            ``spec_code`` has been defined.
    """
    if spec_code is None:
        # Backward compatibility: the original read this global implicitly.
        try:
            spec_code = globals()["spec_code"]
        except KeyError:
            raise NameError(
                "spec_code was not passed and no module-level spec_code is defined"
            ) from None

    # SECURITY: eval() executes arbitrary Python typed by the user. This is only
    # acceptable for a locally-run tool; never expose this app to untrusted users.
    # The module globals (V, E, r, ...) supply the spec DSL names.
    model_obj.update_spec(eval(spec_code))

    if not st.button("Run spec analysis"):
        st.markdown("### Press the button to run")
        return

    with st.spinner("Running analysis"):
        model_obj.run_eval_loop(progress_bar=stqdm)
    with st.spinner("Generating spec chart"):
        data_values = model_obj.get_tabular_rep()
        viz_spec = create_viz_spec(data_values)
        add_vega_chart(viz_spec)

if __name__ == "__main__":
    # Streamlit entry point: choose a dataset, model, target/inputs and a spec,
    # then build the model and render the spec analysis.
    st.set_page_config(layout='wide')

    # --- Dataset selection ---
    datasets = get_dataset_names()
    selected_dataset_name = st.selectbox("Select dataset", datasets)
    selected_dataset = get_dataset(selected_dataset_name)

    dataset_attr_dict = get_dataset_attr_dict(selected_dataset_name)
    st.markdown("## Dataset attributes ")
    # Reuse the dict fetched above instead of calling get_dataset_attr_dict twice.
    st.write(dataset_attr_dict)
    attr_list = selected_dataset.attributes

    # --- Model and variable selection ---
    selected_model_name = st.selectbox("Select ML model", get_models())
    selected_model = get_model(selected_model_name)

    if selected_model_name == _view_maintenance:
        # For this model the target and inputs are fixed to the dataset defaults.
        output_var = selected_dataset.target
        input_vars = attr_list
    else:
        # NOTE(review): st.beta_columns was renamed st.columns in later Streamlit
        # releases; switch once the pinned Streamlit version allows it.
        col1, col2 = st.beta_columns(2)
        default_target = selected_dataset.target
        # Fall back to the first attribute instead of crashing when the
        # dataset's default target is not among its attributes.
        target_index = attr_list.index(default_target) if default_target in attr_list else 0
        output_var = col1.selectbox("Choose output variable", attr_list, index=target_index)
        input_vars = col2.multiselect("Choose input variables", attr_list, default=selected_dataset.inputs)
        if output_var in input_vars:
            st.markdown("***NOTE: input includes output***")

    st.markdown(f"These attributes can be used as variables in the spec: {', '.join(input_vars)} ")
    st.markdown("Note: Return variable is **r**,"
                " and other variables must be input in quotes. 'x' refers to the full input")

    # --- Spec entry ---
    # Example specs per dataset; unknown datasets fall back to the Boston one.
    test_specs = {
        "Adult Income": 'E(r, given=(V("sex_Male") == 1)) / E(r, given=(V("sex_Female") == 1)) < 1.2',
        "Boston Housing Prices": 'E(r, given=(V("lstat") > 12)) / E(r, given=(V("lstat") < 12)) > 0.6',
        "Rate My Professors": 'E(r, given=(V("gender_male") == 1) & (V("male_dominated_department") == 1)) / E(r, given=(V("gender_female") == 1) & (V("male_dominated_department") == 1)) < 1.2',
        "Compas": '(E(V("two_year_recid"),given=(V("score_text_Low") == 0) & (V("race_African-American") == 1))) / (E(V("two_year_recid"),given=(V("score_text_Low") == 0) & (V("race_Caucasian") == 1))) > 1.0'
    }
    initial_spec = test_specs.get(selected_dataset_name, test_specs["Boston Housing Prices"])
    # Module-level on purpose: provide_model_eval_interface reads this global
    # when it is not passed explicitly.
    spec_code = st.text_input("Spec:", initial_spec)

    # --- Training and analysis ---
    try:
        with st.spinner("Training model..."):
            model_obj = CustomDatasetModel(selected_dataset, output_var, input_vars, selected_model)
        provide_model_eval_interface(model_obj)
    except InvalidTargetError:
        st.markdown("Oops! Looks like the chosen output variable isn't compatible with the selected model. Select another configuration")
    except InvalidTrainError:
        st.markdown("Oops! Looks like the chosen training data isn't compatible with the selected model. Select another configuration")
