from data import CausalDataset, MoralDataset, AbstractDataset, Example, FactorUtils
from sklearn.neural_network import MLPClassifier, MLPRegressor
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import numpy as np

from tqdm import tqdm

from typing import List, Tuple

class ExpertReasoningEngine:
    """Classifier over binary factor-annotation features of dataset examples.

    This base ``__init__`` only stores the dataset and builds the model.
    ``factors``, ``tag_to_feat_map``, ``labels`` and the training arrays
    ``X``, ``y_scores`` and ``y`` are not set here; they must be assigned
    (e.g. by a subclass) before :meth:`train` or :meth:`predict` is called.
    """

    # Factor names; each needs ``<factor>_answers`` and
    # ``<factor>_answers_map`` attributes on ``FactorUtils``.
    factors: List[str]
    # Feature vocabulary: position i of a feature vector stands for entry i.
    tag_to_feat_map: List[str]
    # Answer strings; an example's class id is the index of its answer here.
    labels: List[str]
    # Feature rows, one per example.
    X: np.ndarray
    # First entry of each example's ``answer_dist``.
    y_scores: np.ndarray
    # Class ids (indices into ``labels``), one per example.
    y: np.ndarray

    def __init__(self, dataset: AbstractDataset):
        """Store ``dataset`` and create the (untrained) MLP classifier."""
        self.dataset = dataset
        # Fixed random_state keeps the classifier's initialisation repeatable.
        self.model = MLPClassifier(solver='lbfgs', alpha=1e-5,
                                   hidden_layer_sizes=(64, 64), random_state=1)

    def generate_label_vocab(self) -> List[str]:
        """Concatenate ``FactorUtils.<factor>_answers`` for every factor.

        Returns:
            The answers of all factors in ``self.factors``, in factor order.

        Raises:
            AttributeError: if ``FactorUtils`` has no ``<factor>_answers``.
        """
        vocab: List[str] = []
        for factor in self.factors:
            # Plain attribute lookup; unlike eval() this can never execute
            # an expression smuggled into a factor name.
            vocab.extend(getattr(FactorUtils, f"{factor}_answers"))

        return vocab

    def featurize_example(self, ex: Example) -> List[int]:
        """Encode the annotations of ``ex`` as a 0/1 vector.

        Each annotated sentence sets to 1 the position of its mapped answer
        (``FactorUtils.<factor>_answers_map[value]``) in ``tag_to_feat_map``.

        Returns:
            A list of ints with the same length as ``self.tag_to_feat_map``.

        Raises:
            AttributeError: if the factor has no ``_answers_map`` attribute.
            ValueError: if the mapped answer is not in ``tag_to_feat_map``.
        """
        feat = [0] * len(self.tag_to_feat_map)
        for sent in ex.annotated_sentences:
            annotation = sent.annotation
            # getattr replaces eval(): same lookup, no code execution.
            answers_map = getattr(FactorUtils, f"{annotation.factor}_answers_map")
            feat[self.tag_to_feat_map.index(answers_map[annotation.value])] = 1

        return feat

    def process_dataset(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Featurize every example in ``self.dataset.examples``.

        Returns:
            ``(X, y_scores, y)`` as numpy arrays: feature rows, the first
            entry of each example's ``answer_dist``, and the index of each
            example's ``answer`` within ``self.labels``.

        Raises:
            ValueError: if an example's answer is not in ``self.labels``.
        """
        X, y_scores, y = [], [], []
        ex: Example
        for ex in tqdm(self.dataset.examples):
            X.append(self.featurize_example(ex))
            y_scores.append(ex.answer_dist[0])
            y.append(self.labels.index(ex.answer))

        return np.array(X), np.array(y_scores), np.array(y)

    def train(self):
        """Fit the model on ``self.X`` / ``self.y`` and print training accuracy."""
        self.model.fit(self.X, self.y)
        print("Training finished")
        # Accuracy is measured on the same data the model was fitted on.
        y_hat = self.model.predict(self.X)
        print("Training accuracy:", accuracy_score(self.y, y_hat))

    def predict(self, ex: Example) -> List[float]:
        """Return the model's first two class probabilities for ``ex``.

        The values are columns 0 and 1 of ``predict_proba`` for the single
        featurized example.
        NOTE(review): presumably these line up with ``labels[0]`` and
        ``labels[1]``; that requires both classes to be present in the
        training data - confirm.
        """
        feat = self.featurize_example(ex)
        choice_scores = self.model.predict_proba([feat])
        return [choice_scores[0][0], choice_scores[0][1]]

class CausalReasoningEngine(ExpertReasoningEngine):
    """Reasoning engine configured with the causal factor set.

    Construction builds the feature vocabulary from the causal factors and
    immediately featurizes the whole dataset into ``X``, ``y_scores``, ``y``.
    """

    def __init__(self, dataset: CausalDataset):
        super().__init__(dataset)
        self.labels = ['Yes', 'No']
        self.factors = [
            'causal_structure',
            'agent_awareness',
            'action_omission',
            'event_normality',
            'time',
            'norm_type',
        ]
        # The vocabulary depends on self.factors, so it is built after them.
        self.tag_to_feat_map = self.generate_label_vocab()

        # Needs both the vocabulary and the labels to be in place.
        features, scores, targets = self.process_dataset()
        self.X = features
        self.y_scores = scores
        self.y = targets

class MoralReasoningEngine(ExpertReasoningEngine):
    """Reasoning engine configured with the moral factor set.

    Construction builds the feature vocabulary from the moral factors and
    immediately featurizes the whole dataset into ``X``, ``y_scores``, ``y``.
    """

    def __init__(self, dataset: MoralDataset):
        super().__init__(dataset)
        self.labels = ['Yes', 'No']
        self.factors = [
            'locus_of_intervention',
            'beneficiary',
            'personal_force',
            'evitability',
            'causal_role',
        ]
        # The vocabulary depends on self.factors, so it is built after them.
        self.tag_to_feat_map = self.generate_label_vocab()

        # Needs both the vocabulary and the labels to be in place.
        features, scores, targets = self.process_dataset()
        self.X = features
        self.y_scores = scores
        self.y = targets

if __name__ == '__main__':
    # Causal demo: train on one dataset instance, then score the first
    # example of a freshly constructed dataset next to its gold answer.
    causal_engine = CausalReasoningEngine(CausalDataset())
    causal_engine.train()  # 0.7569 (presumably the training accuracy seen earlier)
    causal_data = CausalDataset()
    print(causal_engine.predict(causal_data[0]), causal_data[0].answer)

    # Moral demo: same flow, printing predictions for examples 0 and 13.
    moral_engine = MoralReasoningEngine(MoralDataset())
    moral_engine.train()  # 0.8064 (presumably the training accuracy seen earlier)
    moral_data = MoralDataset()
    for idx in (0, 13):
        print(moral_engine.predict(moral_data[idx]), moral_data[idx].answer)