"""Characterization of the transitions between sets of states.

A transition predicate is a function that takes as input a vector
corresponding to the concatenation of a state and an action and maps it
to a number. If the number is greater than zero, the predicate is true;
otherwise, the predicate is false.
"""
from swmpo.transition import Transition
from swmpo.transition import get_vector
from collections import defaultdict
from itertools import product
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.neural_network import MLPClassifier
from sklearn.tree import DecisionTreeClassifier
import random
from dataclasses import dataclass


@dataclass
class TransitionPredicates:
    """Adjacency-matrix view of the transitions between partition nodes.

    Both matrices are square and indexed as ``[source_node][target_node]``.
    """
    # Entry [i][j] is the fitted predicate for moving from node i to node j.
    # get_transition_predicates builds these as sklearn Pipelines
    # (StandardScaler followed by MLPClassifier), or stores None when node i
    # contains no transitions.
    transition_predicates: list[list[Pipeline | None]]
    # Entry [i][j] counts the transitions of node i whose next state is the
    # source state of at least one transition of node j.
    transition_histogram: list[list[int]]


def get_transition_predicates(
    partition: list[list[Transition]],
    predicate_hyperparameters: dict,
    seed: str,
) -> TransitionPredicates:
    """Returns the adjacency matrix containing the
    predicates and histogram of transitions between states.

    For every ordered pair of nodes ``(i, j)`` a binary classifier is
    fitted on the transitions of node ``i``. A transition is a positive
    example when its next state is the source state of some transition in
    node ``j`` ("yes, transition from i to j"), and a negative example
    otherwise ("no, do not transition from i to j").

    Args:
        partition: Transitions grouped into nodes; ``partition[i]`` holds
            the transitions of node ``i``.
        predicate_hyperparameters: Extra keyword arguments forwarded to
            ``MLPClassifier``. It must not contain ``random_state``, which
            is drawn here from ``seed``.
        seed: Seed of the ``random.Random`` instance that draws each
            classifier's ``random_state``.

    Returns:
        ``TransitionPredicates`` whose ``[i][j]`` entries are the fitted
        predicate for ``i -> j`` (``None`` when node ``i`` is empty) and
        the number of positive examples for ``i -> j``.
    """
    _random = random.Random(seed)
    state_indices = list(range(len(partition)))
    transition_predicates: list[list[Pipeline | None]] = [
        [None for _ in state_indices] for _ in state_indices
    ]
    transition_histogram = [[0 for _ in state_indices] for _ in state_indices]

    # SPEED: make source state existence look-up fast for
    # predicate dataset construction. Otherwise checking if a
    # "next state" is the "source state" of any transition is
    # prohibitively slow.
    source_states = [
        {tuple(transition.source_state.tolist()) for transition in subset}
        for subset in partition
    ]
    # /SPEED

    # Iterate (i, j) in row-major order so that the sequence of
    # _random.getrandbits draws matches a product(i, j) loop and seeded
    # results stay reproducible.
    for i in state_indices:
        # The features and next states of node i do not depend on the target
        # node j, so compute them once per node rather than once per pair.
        # NOTE(review): assumes get_vector is pure and returns a CPU tensor
        # (presumably torch; `.numpy()` fails on CUDA tensors) - confirm.
        next_states = [
            tuple(transition.next_state.tolist())
            for transition in partition[i]
        ]
        vectors = [
            get_vector(transition).detach().numpy()
            for transition in partition[i]
        ]

        for j in state_indices:
            # Split node i's transitions by whether their next state is a
            # source state of node j.
            positive = []
            negative = []
            for vector, next_state in zip(vectors, next_states):
                if next_state in source_states[j]:
                    positive.append(vector)
                else:
                    negative.append(vector)

            # NOTE: the dataset is deliberately left unbalanced. A variant
            # that subsampled both classes to the smaller class size (using
            # _random) used to be kept here as commented-out code.
            X = positive + negative
            Y = [1] * len(positive) + [0] * len(negative)

            # Synthesize a transition predicate. Nothing can be learned
            # from an empty node.
            if not X:
                transition_predicate = None
            else:
                # A DecisionTreeClassifier was previously tried as an
                # alternative classifier here.
                transition_predicate = Pipeline([
                    ("normalizer", StandardScaler()),
                    ("classifier", MLPClassifier(
                        random_state=_random.getrandbits(32),
                        **predicate_hyperparameters,
                    )),
                ])
                transition_predicate.fit(X, Y)

            transition_predicates[i][j] = transition_predicate
            # Store the non-conditional transition count.
            transition_histogram[i][j] = len(positive)

    return TransitionPredicates(
        transition_predicates=transition_predicates,
        transition_histogram=transition_histogram,
    )
