from typing import List, Tuple, Callable, Optional, Dict, Union

import numpy as np
import pandas as pd
import torch
from numpy.typing import NDArray

from core.abstract_mdp import AbstractMDP
from core.cluster import Clusterer, KMeans, DBSCAN
from core.data import TransitionData, AbstractState, FactorValue, Factor, AbstractStateDict
from core.msa import MSAFlat, MSAObject
from core.tests import knn_ind_test, knn_accuracy


def partition(data: pd.DataFrame, *column_names) -> List:
    """Split ``data`` into one sub-frame per distinct value of ``column_names``.

    The sub-frames come back in the order ``DataFrame.groupby`` iterates the
    groups; the group keys themselves are discarded.
    """
    grouped = data.groupby(list(column_names), as_index=False)
    return [frame for _key, frame in grouped]


def make_state(factor_values: List[FactorValue],
               state_values: torch.Tensor,
               initiation_vector: Tuple[int, ...],
               action_counts: np.ndarray,
               info: Optional[List[Dict]] = None) -> AbstractState:
    """Create an ``AbstractState`` with a placeholder id of 0.

    ``info`` is attached only when it is given; otherwise the attribute keeps
    whatever the constructor assigned. Callers in this module overwrite the id
    afterwards.
    """
    new_state = AbstractState(0, initiation_vector, factor_values,
                              state_values, action_counts, None)
    if info is None:
        return new_state
    new_state.info = info
    return new_state


def add_state(states: List[AbstractState],
              new_state: AbstractState) -> AbstractState:
    """Merge ``new_state`` into an equal state of ``states`` or append it.

    The first element comparing equal to ``new_state`` absorbs its data
    (concatenated along dim 0), its action counts (added in place) and its
    info entries, and is returned. If none matches, ``new_state`` receives the
    id ``len(states)``, is appended to ``states`` and returned.
    """
    match = next((candidate for candidate in states if candidate == new_state), None)
    if match is None:
        new_state.id = len(states)
        states.append(new_state)
        return new_state

    match.data = torch.cat([match.data, new_state.data], dim=0)
    match.action_counts += new_state.action_counts
    if new_state.info:
        if not match.info:
            match.info = list()
        match.info.extend(new_state.info)
    return match


def find_or_make_fv(fv, factor, factor_ids, factor_values, threshold=(0.35, 0.65), n_reps=5, max_sample=100):
    """Return the id of the value of ``factor`` that the samples ``fv`` belong to.

    Existing values of ``factor`` are visited in id order. For each, ``n_reps``
    random index draws (with replacement) of at most ``max_sample`` points are
    taken from ``fv`` and from the stored samples, and ``knn_accuracy`` is
    averaged. If that mean lies strictly between ``threshold[0]`` and
    ``threshold[1]``, the two sample sets are treated as the same value:
    ``fv`` is prepended to the stored samples and the existing id is returned.
    Otherwise a fresh id is registered.

    Mutates ``factor_ids[factor]`` and ``factor_values[factor]`` in place.
    """
    for candidate_id in factor_ids[factor]:
        existing = factor_values[factor][candidate_id]
        sample_size = min(max_sample, len(fv), len(existing))
        idx_new = torch.randint(0, len(fv), (n_reps, sample_size))
        idx_old = torch.randint(0, len(existing), (n_reps, sample_size))
        accuracy = knn_accuracy(fv[idx_new], existing[idx_old]).mean()
        # accuracy near chance => kNN cannot tell the two sets apart
        if threshold[0] < accuracy < threshold[1]:
            factor_values[factor][candidate_id] = torch.cat([fv, existing], dim=0)
            return candidate_id

    new_id = len(factor_ids[factor])
    factor_ids[factor].append(new_id)
    factor_values[factor][new_id] = fv
    return new_id


def compute_error(samples: NDArray[np.object_],
                  all_factors: List[Factor],
                  n_tests: int = 50,
                  effect_factors: Optional[List[Factor]] = None) -> Tuple[float, float]:
    """Score one bucket of ground transitions.

    Start states are restricted to the variables of ``all_factors``. Next
    states are restricted to the variables of ``effect_factors``, or to those
    of ``all_factors`` when ``effect_factors`` is None.

    Returns:
        (value returned by ``knn_ind_test`` on the two restricted tensors,
         variance of the rewards).
    """
    states, next_states, rewards = _get_states_from_ndarray(samples)

    state_vars = [v for factor in all_factors for v in factor.variables]
    if effect_factors is None:
        next_vars = state_vars
    else:
        next_vars = [v for factor in effect_factors for v in factor.variables]

    states_t = torch.tensor(states, dtype=torch.float)[..., state_vars]
    next_states_t = torch.tensor(next_states, dtype=torch.float)[..., next_vars]

    transition_error = knn_ind_test(states_t, next_states_t, n_tests=n_tests)
    return transition_error, float(np.var(rewards))


def compute_error_msa(samples: NDArray[np.object_], msa: Union[MSAFlat, MSAObject],
                      k: int = 10) -> Tuple[float, float]:
    """Score one bucket of ground transitions with a learned ``msa`` model.

    ``k * len(next states)`` negatives are drawn uniformly with replacement
    from the bucket's own next states.

    Returns:
        (``max(0, log(k + 1) - msa.density_loss(s, s', negatives))`` as a
         Python float, variance of the rewards).
    """
    starts, nexts, rewards = _get_states_from_ndarray(samples)
    starts = torch.tensor(starts, dtype=torch.float)
    nexts = torch.tensor(nexts, dtype=torch.float)

    n_next = len(nexts)
    negatives = nexts[torch.randint(0, n_next, (k * n_next,))]
    log_n = torch.log(torch.tensor(k + 1))
    loss = msa.density_loss(starts, nexts, negatives)
    mi_ground = torch.relu(log_n - loss).item()
    return mi_ground, float(np.var(rewards))


def compute_mdp_error(abstract_states: AbstractStateDict,
                      all_factors: List[Factor],
                      n_tests: int = 50,
                      changed_factors: Optional[Dict[int, List[Factor]]] = None,
                      msa: Optional[Union[MSAFlat, MSAObject]] = None,
                      ) -> Dict[Tuple[Tuple[int, int], ...], Tuple[float, float]]:
    """Estimate a (transition, reward) error for every abstract state.

    For each state key, every action's ground transitions are scored with
    ``compute_error_msa`` (when ``msa`` is given) or ``compute_error``. The
    per-action errors are averaged with weights proportional to each action's
    sample count. Each state's pair is then multiplied by that state's share
    of all samples.

    Args:
        abstract_states: the abstraction to evaluate.
        all_factors: factors whose variables describe start states.
        n_tests: forwarded to ``compute_error`` (unused on the ``msa`` path).
        changed_factors: optional action -> factors that action changes; used
            to restrict next states in ``compute_error``.
        msa: optional learned model; switches scoring to ``compute_error_msa``.

    Returns:
        Dict mapping state key -> (weighted transition error, weighted reward
        error), ordered by descending (transition, reward) error. States with
        no outgoing samples get (0.0, 0.0). If no state has any samples, every
        entry is 0.
    """
    state_errors = {}
    state_counts = {}

    for s_key in abstract_states.get_keys():
        action_samples = abstract_states.get_transitions(s_key)
        n_valid_actions = len(action_samples)
        n_samples = sum(len(x) for x in action_samples.values())

        if n_samples == 0:
            # no outgoing transitions to score (the state is presumably only
            # reached as a next state); it gets zero weight below
            state_errors[s_key] = (0.0, 0.0)
            state_counts[s_key] = 0
            continue

        t_error = np.zeros(n_valid_actions, dtype=float)
        r_error = np.zeros(n_valid_actions, dtype=float)
        action_counts = np.zeros(n_valid_actions, dtype=int)

        for i, action in enumerate(action_samples):
            # NOTE: we don't check if this action is indeed executable.
            # Not being executable would mean something is wrong with the environment.
            # But it would be nice to add this as a test.
            effect_factors = None
            if changed_factors is not None:
                effect_factors = changed_factors[action]

            samples = action_samples[action]
            if len(samples) == 0:
                continue

            if msa is not None:
                te, re = compute_error_msa(samples, msa)
            else:
                te, re = compute_error(samples, all_factors, n_tests, effect_factors)

            t_error[i] = te
            r_error[i] = re
            action_counts[i] = len(samples)

        Z = action_counts.sum()
        assert Z != 0

        state_counts[s_key] = Z
        P = action_counts / Z
        state_errors[s_key] = ((t_error*P).sum(), (r_error*P).sum())

    total = sum(state_counts.values())
    if total > 0:
        weights = {s_key: count / total for s_key, count in state_counts.items()}
    else:
        # every state is sample-less: avoid 0/0 and report zero error
        weights = {s_key: 0.0 for s_key in state_counts}

    weighted = {k: (v[0] * weights[k], v[1] * weights[k]) for k, v in state_errors.items()}
    return dict(sorted(weighted.items(), key=lambda x: (x[1][0], x[1][1]), reverse=True))


def refine(factor_key: Tuple[int, int],
           abstract_states: AbstractStateDict,
           factor: Tuple[int, Factor],
           clusterer: Clusterer,
           min_samples: int = 10) -> Union[AbstractStateDict, None]:
    """Try to split one factor value in two by clustering its samples.

    The ground samples whose factor currently takes the value ``factor_key =
    (id, r)`` are clustered on that factor's variables. On success, a copy of
    ``abstract_states`` is returned in which ``(id, r)`` is replaced by the
    children ``(id, 2r)`` (cluster label 0) and ``(id, 2r + 1)`` (every other
    label). The factor->ground, abstract->ground, ground->abstract and
    (abstract, action)->ground mappings of the copy are rewritten to match.

    Args:
        factor_key: (factor value id, refinement index) to split.
        abstract_states: current abstraction. All edits go to
            ``abstract_states.copy()``. NOTE(review): whether the original is
            fully untouched depends on how deep that copy is — confirm.
        factor: (position of the factor inside state keys, the factor).
        clusterer: clusterer to fit; this code uses ``fit``, ``predict`` and
            ``_clusterer.labels_``.
        min_samples: minimum number of samples required to attempt a split.

    Returns:
        The refined copy, or None when there are fewer than ``min_samples``
        samples, all of the factor's variables are (near-)constant, or the
        clustering yields a single label.
    """
    idx, f_i = factor
    new_states = abstract_states.copy()
    fsamples, next_fsamples = new_states.get_factor_values(f_i.name, factor_key)

    if len(fsamples) < min_samples:
        # cannot refine if n_samples is less than min_samples
        return None

    fstates = np.stack(fsamples[:, 0].tolist())[..., f_i.variables]
    std = np.std(fstates, axis=0)
    if (std < 1e-5).all():
        # cannot refine if all variables are constant
        return None

    clusterer.fit(fstates)
    labels = np.asarray(clusterer._clusterer.labels_)

    if len(np.unique(labels)) == 1:
        # the refinement is not successful as all samples
        # are assigned to the same cluster
        return None

    # The rest updates the relations in the abstract_states struct.
    # Label 0 -> first child, any other label (a further cluster, or e.g. a
    # -1 noise label) -> second child. The per-state remapping below uses the
    # same rule, so the factor-level and state-level views stay consistent
    # and no sample is dropped.
    n_fv1 = (factor_key[0], factor_key[1]*2)
    n_fv2 = (factor_key[0], factor_key[1]*2+1)
    f_to_ground = new_states._factor_to_ground[f_i.name]

    # update factor -> ground mapping
    old_indices, next_old_indices = f_to_ground[factor_key]
    in_first = labels == 0
    f_to_ground[n_fv1] = [[old_indices[i] for i in np.where(in_first)[0]], []]
    f_to_ground[n_fv2] = [[old_indices[i] for i in np.where(~in_first)[0]], []]

    # update factor -> ground mapping for next states
    if len(next_fsamples) > 0:
        next_fstates = np.stack(next_fsamples[:, 3].tolist())[..., f_i.variables]
        next_in_first = np.asarray(clusterer.predict(next_fstates)) == 0
        f_to_ground[n_fv1][1] = [next_old_indices[i] for i in np.where(next_in_first)[0]]
        f_to_ground[n_fv2][1] = [next_old_indices[i] for i in np.where(~next_in_first)[0]]
    del f_to_ground[factor_key]

    # update abstract_state -> ground_indices and
    # ground_indices -> abstract_state mappings.
    for key in new_states.get_keys().copy():
        # only states whose refined factor takes the split value change
        if key[idx] != factor_key:
            continue

        new_key1 = key[:idx] + (n_fv1,) + key[idx+1:]
        new_key2 = key[:idx] + (n_fv2,) + key[idx+1:]
        samples, next_samples = new_states.get_ground_samples(key)
        old_indices, next_old_indices = new_states._abs_to_ground[key]

        # update state distributions wrt. clustering
        if len(samples) > 0:
            states = np.stack(samples[:, 0].tolist())[..., f_i.variables]
            state_labels = clusterer.predict(states)
            for i, x in enumerate(state_labels):
                abs_idx = old_indices[i]
                key_to_be_set = new_key1 if x == 0 else new_key2

                if key_to_be_set not in new_states._abs_to_ground:
                    new_states._abs_to_ground[key_to_be_set] = [[], []]

                new_states._abs_to_ground[key_to_be_set][0].append(abs_idx)
                new_states._ground_to_abs[abs_idx, 0] = key_to_be_set

        # update next state distributions wrt. clustering
        if len(next_samples) > 0:
            next_states = np.stack(next_samples[:, 3].tolist())[..., f_i.variables]
            next_state_labels = clusterer.predict(next_states)
            for i, x in enumerate(next_state_labels):
                abs_idx = next_old_indices[i]
                key_to_be_set = new_key1 if x == 0 else new_key2

                if key_to_be_set not in new_states._abs_to_ground:
                    new_states._abs_to_ground[key_to_be_set] = [[], []]

                new_states._abs_to_ground[key_to_be_set][1].append(abs_idx)
                new_states._ground_to_abs[abs_idx, 1] = key_to_be_set

        # update abstract_state, action -> ground mapping; relies on
        # _ground_to_abs[i, 0] having just been rewritten above
        if key in new_states._abs_w_act_to_ground:
            for action in new_states._abs_w_act_to_ground[key]:
                indices = new_states._abs_w_act_to_ground[key][action]
                for i in indices:
                    x = new_states._ground_to_abs[i, 0]

                    if x not in new_states._abs_w_act_to_ground:
                        new_states._abs_w_act_to_ground[x] = {}

                    if action not in new_states._abs_w_act_to_ground[x]:
                        new_states._abs_w_act_to_ground[x][action] = []

                    new_states._abs_w_act_to_ground[x][action].append(i)

            del new_states._abs_w_act_to_ground[key]

        del new_states._abs_to_ground[key]

    return new_states


def build_model(abstract_states: AbstractStateDict,
                option_names: List[str],
                factors: List[Factor],
                **kwargs) -> Tuple[AbstractMDP, Dict[Tuple[Tuple[int, int], ...], AbstractState]]:
    """Build an ``AbstractMDP`` from the current abstraction.

    One ``AbstractState`` is created per key of ``abstract_states``. Its data
    are the start states of the ground samples assigned to the key, followed by
    the next states of the samples whose successor is assigned to the key. Its
    action counts are counted over the former. Transitions returned by
    ``get_abstract_transitions`` are then added to the MDP.

    Ground sample rows use the column layout built in
    ``initialise_abstract_states``: 0 state, 1 option, 2 reward, 3 next_state,
    4 done, 5 steps, 6 options_available, 7 next_options_available,
    8 state_info, 9 next_state_info.

    Optional keyword arguments:
        state_visualiser: called with the list of created states.
        transition_visualiser: called with ``(abstract_transitions, option_names)``.

    Returns:
        The MDP and a dict mapping each state key to its ``AbstractState``.
    """
    states = dict()
    state_list = list()
    for s_key in abstract_states.get_keys():
        factor_values = _create_factor_values_from_key(s_key, factors)
        at_state, at_next_state = abstract_states.get_ground_samples(s_key)
        if len(at_state) != 0:
            init_vec = at_state[0, 6]
        else:
            # This can only happen if the ground samples are always in the
            # next state, e.g. goal states. The state is the *successor* in
            # these rows, so take its initiation vector from
            # next_options_available (column 7), matching the next-state
            # columns 3/9 used below. Column 6 would be the predecessor's.
            init_vec = at_next_state[0, 7]

        info = at_state[:, 8].tolist() + at_next_state[:, 9].tolist()
        state = torch.tensor(np.stack(at_state[:, 0].tolist() + at_next_state[:, 3].tolist()), dtype=torch.float)
        action_counts = np.zeros(len(option_names), dtype=int)
        if len(at_state) != 0:
            uniques, counts = np.unique(np.stack(at_state[:, 1].tolist()), return_counts=True)
            action_counts[uniques] = counts
        s = make_state(factor_values, state, init_vec, action_counts, info)
        s.id = len(states)
        state_list.append(s)
        states[s_key] = s

    mdp = AbstractMDP(state_list, option_names, rounding_rule=0.05)

    # add transitions
    # NOTE(review): abstract_transitions is never filled in this function, so
    # transition_visualiser always receives an empty list. Confirm the
    # expected element format before populating it.
    abstract_transitions = list()
    for s_key, state in states.items():
        transitions = abstract_states.get_abstract_transitions(s_key)
        for a in transitions:
            for ns_key, (prob, reward, steps) in transitions[a].items():
                next_state = states[ns_key]
                mdp.add_transition(state, a, next_state, prob, reward, steps)

    state_visualiser: Optional[Callable] = kwargs.get("state_visualiser", None)
    if state_visualiser:
        state_visualiser(state_list)
    transition_visualiser: Optional[Callable] = kwargs.get("transition_visualiser", None)
    if transition_visualiser:
        transition_visualiser(abstract_transitions, option_names)
    return mdp, states


def initialise_abstract_states(data: pd.DataFrame,
                               factors: List[Factor]) -> AbstractStateDict:
    """Create the initial abstraction from raw transition data.

    Rows are grouped by their initiation vector (``options_available``). For
    each group and each factor, the group's values of that factor are matched
    against previously created factor values with ``find_or_make_fv``. The
    resulting ids form the group's key ``((id_0, 1), (id_1, 1), ...)``.

    Next states reuse the key of a start-state group with the same initiation
    vector when one exists. Otherwise a key is derived from the next-state
    values in the same way.

    Keys are assigned by row position, so the frame's index labels are
    irrelevant. ``data`` itself is not modified.
    """
    # key_values/next_key_values are filled with the groups' index labels, so
    # make those labels positional (0..n-1). This is a no-op for frames that
    # already have a default RangeIndex.
    data = data.reset_index(drop=True)

    init_splits = partition(data, "options_available")
    effect_splits = partition(data, "next_options_available")
    factor_ids = {f: [] for f in factors}
    factor_values = {f: {} for f in factors}
    key_names = [f.name for f in factors]
    key_values = np.zeros((len(data),), dtype=object)
    next_key_values = np.zeros((len(data),), dtype=object)
    init_key_map = {}
    for init in init_splits:
        key = ()
        init_vec = init.iloc[0]["options_available"]
        group_states = np.stack(init["state"].values)
        for f in factor_ids:
            fv = torch.tensor(group_states[..., f.variables], dtype=torch.float)
            # compare the current factor values with old values, and
            # use the old factor value if they significantly overlap
            f_id = find_or_make_fv(fv, f, factor_ids, factor_values)
            key = key + ((f_id, 1),)
        init_key_map[init_vec] = key
        # element-wise assignment: a slice assignment would try to broadcast
        # the key tuple across the object array
        for i in init.index.to_numpy():
            key_values[i] = key

    # set next_key_values based on previously set init_vec->key mapping
    for effect in effect_splits:
        next_init_vec = effect.iloc[0]["next_options_available"]
        if next_init_vec in init_key_map:
            next_key = init_key_map[next_init_vec]
        else:
            # This initiation vector never occurs at a start state (most
            # probably goal states), so derive a key from the next-state
            # values, reusing overlapping factor values where possible.
            next_key = ()
            group_next_states = np.stack(effect["next_state"].values)
            for f in factor_ids:
                next_fv = torch.tensor(group_next_states[..., f.variables], dtype=torch.float)
                f_id = find_or_make_fv(next_fv, f, factor_ids, factor_values)
                next_key = next_key + ((f_id, 1),)
        for i in effect.index.to_numpy():
            next_key_values[i] = next_key

    # create the abstract state dict
    data_values = data[["state", "option", "reward", "next_state", "done",
                        "steps", "options_available", "next_options_available",
                        "state_info", "next_state_info"]].to_numpy()
    abstract_states = AbstractStateDict(key_names=key_names, key_values=key_values,
                                        next_key_values=next_key_values, data=data_values)

    return abstract_states


def build(transition_data: Optional[TransitionData],
          option_names: List[str],
          factors: List[Factor],
          abstract_states: Optional[AbstractStateDict] = None,
          transition_error_delta: Optional[float] = None,
          reward_error_delta: Optional[float] = None,
          min_transition_error: Optional[float] = None,
          min_reward_error: Optional[float] = None,
          changed_factors: Optional[Dict[int, List[Factor]]] = None,
          n_cluster_trials: int = 1,
          min_samples: int = 10,
          **kwargs) -> Tuple[AbstractMDP, AbstractStateDict, Dict]:
    """Greedily refine an abstraction and return the resulting abstract MDP.

    Starting from ``abstract_states`` (or, when it is None, an abstraction
    initialised from ``transition_data``), each pass walks the factor values in
    descending-error order and tries to split one with ``refine``. For each
    candidate, ``n_cluster_trials`` clusterings are tried, and the one with the
    lowest total transition error is kept. The first candidate that passes the
    ``*_delta`` checks is accepted and a new pass starts. The loop ends after a
    pass in which nothing is accepted.

    Args:
        transition_data: ground transitions; only used when ``abstract_states``
            is None.
        option_names: option names, forwarded to ``build_model``.
        factors: state factors; must be non-empty.
        abstract_states: optional starting abstraction.
        transition_error_delta, reward_error_delta: if set, a refinement is
            rejected unless it lowers the total transition / reward error by
            at least this amount.
        min_transition_error, min_reward_error: if set, factor values whose
            error is below the threshold are not considered for refinement.
        changed_factors: option -> factors it changes, forwarded to
            ``compute_mdp_error``.
        n_cluster_trials: clusterings tried per candidate factor value.
        min_samples: forwarded to ``refine``.
        **kwargs: ``msa`` is forwarded to ``compute_mdp_error``. ``render`` is
            called as ``render(abstract_states, mdp, mdp_error, factor_error)``
            after the initial build and after each accepted refinement. All
            kwargs are also forwarded to ``build_model`` and ``_get_clusterer``.

    Returns:
        ``(mdp, abstract_states, mdp_error)`` for the final abstraction.
    """
    assert len(factors) > 0
    msa = kwargs.get("msa", None)
    render = kwargs.get("render", lambda *args, **kwargs: 0)

    if abstract_states is None:
        assert transition_data is not None
        data = transition_data.data.copy()
        abstract_states = initialise_abstract_states(data, factors)

    mdp, _ = build_model(abstract_states, option_names, factors, **kwargs)
    mdp_error = compute_mdp_error(abstract_states, factors, n_tests=20, changed_factors=changed_factors, msa=msa)
    factor_error = _derive_factor_errors(mdp_error)
    render(abstract_states, mdp, mdp_error, factor_error)

    # start refining
    improved = True
    it = 0
    while improved:
        improved = False
        t_err = _compute_transition_error(mdp_error)
        r_err = _compute_reward_error(mdp_error)
        for f_key, (t_e, r_e) in factor_error.items():
            if min_transition_error is not None and t_e < min_transition_error:
                continue
            if min_reward_error is not None and r_e < min_reward_error:
                continue

            i = f_key[0]
            fval = (f_key[1], f_key[2])
            f_i = factors[i]

            print(f"Refinement proposal={f_key}, t_err={t_e:.5f}, r_err={r_e:.5f}")
            best_states = None
            best_factor = None
            best_mdp_error = {}
            best_factor_error = {}
            best_r_err = np.inf
            best_t_err = np.inf

            for _ in range(n_cluster_trials):
                new_states = refine(fval, abstract_states, (i, f_i), _get_clusterer(**kwargs), min_samples)

                if new_states is None:
                    print(f"Could not refine {f_i}.")
                    continue

                new_mdp_error = compute_mdp_error(new_states, factors, n_tests=20,
                                                  changed_factors=changed_factors, msa=msa)
                new_factor_error = _derive_factor_errors(new_mdp_error)
                new_t_err = _compute_transition_error(new_mdp_error)
                new_r_err = _compute_reward_error(new_mdp_error)
                # keep the trial with the lowest total transition error
                if new_t_err < best_t_err:
                    best_states = new_states
                    best_factor = f_i
                    best_mdp_error = new_mdp_error
                    best_factor_error = new_factor_error
                    best_t_err = new_t_err
                    best_r_err = new_r_err

            if best_states is None:
                continue
            print(f"Factor {best_factor}: t_err={t_err:.5f}->{best_t_err:.5f}, r_err={r_err:.5f}->{best_r_err:.5f}")

            if transition_error_delta is not None and (t_err - best_t_err < transition_error_delta):
                continue

            if reward_error_delta is not None and (r_err - best_r_err < reward_error_delta):
                continue

            it += 1
            print(f"Refinement {it}. From {t_err:.5f}, {r_err:.5f}"
                  f" to {best_t_err:.5f}, {best_r_err:.5f}"
                  f" by refining state {f_key} at {best_factor}.")
            abstract_states = best_states
            mdp_error = best_mdp_error
            factor_error = best_factor_error
            # Rebuild once per accepted refinement. mdp therefore always
            # matches abstract_states, so no extra rebuild is needed after
            # each pass.
            mdp, _ = build_model(abstract_states, option_names, factors, **kwargs)
            render(abstract_states, mdp, mdp_error, factor_error)
            improved = True
            break

    return mdp, abstract_states, mdp_error


def _derive_factor_errors(mdp_error):
    factor_errors = {}
    for s_key in mdp_error:
        t_err, r_err = mdp_error[s_key]
        for i, (f_i, r_i) in enumerate(s_key):
            k = (i, f_i, r_i)
            if k not in factor_errors:
                factor_errors[k] = (0.0, 0.0)

            te, re = factor_errors[k]
            factor_errors[k] = (t_err+te, r_err+re)
    factor_errors = sorted(factor_errors.items(), key=lambda x: (x[1][0], x[1][1]), reverse=True)
    factor_errors = {k: v for k, v in factor_errors}
    return factor_errors


def _compute_transition_error(mdp_error):
    return sum(x_i[0] for x_i in mdp_error.values())


def _compute_reward_error(mdp_error):
    return sum(x_i[1] for x_i in mdp_error.values())


def _get_clusterer(**kwargs) -> Clusterer:
    """Return the clusterer for one refinement trial.

    By default ``KMeans()`` is returned. When ``interactive_debug`` is truthy,
    the user is prompted for a spec:

    - ``k`` or ``k <n_clusters>`` for KMeans;
    - ``d`` or ``d <eps>`` for DBSCAN.

    Only the first character of the first token is inspected, and the argument
    is used only when exactly two tokens are given. Empty input, an unknown
    prefix or a malformed number prints a message and falls back to
    ``KMeans()`` instead of raising.
    """
    clusterer = KMeans()
    if not kwargs.get("interactive_debug", False):
        return clusterer

    # split() (not split(" ")) tolerates leading/repeated whitespace
    tokens = input("Enter clusterer: ").split()
    kind = tokens[0][0] if tokens else ""
    arg = tokens[1] if len(tokens) == 2 else None
    try:
        if kind == "k":
            return KMeans(n_clusters=int(arg)) if arg is not None else KMeans()
        if kind == "d":
            return DBSCAN(eps=float(arg)) if arg is not None else DBSCAN()
    except ValueError:
        # malformed numeric argument such as "k two"; fall through to default
        pass
    print("Invalid clusterer, using 2-means")
    return clusterer


def _get_factors(data: pd.DataFrame):
    """Infer factors by grouping variables that are modified together.

    Every distinct mask (compared as a sorted tuple) gets an id in order of
    first appearance. Each variable records the ids of the distinct masks that
    contain it. Variables with identical id lists form one factor. Factors are
    named ``f0``, ``f1``, ... in order of their lowest variable index.
    """
    n_vars = len(data["state"].iloc[0])
    mask_ids_of_var = {var: [] for var in range(n_vars)}
    seen_masks = {}
    for mask in data["mask"].values:
        signature = tuple(sorted(mask))
        if signature in seen_masks:
            continue
        new_id = len(seen_masks)
        seen_masks[signature] = new_id
        for var in mask:
            mask_ids_of_var[var].append(new_id)

    # dicts keep insertion order, so groups appear in first-variable order
    groups = {}
    for var in range(n_vars):
        groups.setdefault(tuple(mask_ids_of_var[var]), []).append(var)
    return [Factor(f"f{i}", members) for i, members in enumerate(groups.values())]


def _get_modified_factors(n_options: int, factors: List[Factor], data: pd.DataFrame) -> Dict[int, List[Factor]]:
    """Return, for every option id in ``range(n_options)``, the factors it changes.

    A factor counts as changed by an option when all of the factor's variables
    lie in the sample's ``mask`` for more than half of that option's samples.
    Options with no samples map to an empty list.
    """
    change_freq = np.zeros((n_options, len(factors)), dtype=float)
    for option, option_rows in data.groupby("option"):
        masks = option_rows["mask"].values
        for mask in masks:
            mask_vars = set(mask)
            for j, factor in enumerate(factors):
                if set(factor.variables).issubset(mask_vars):
                    change_freq[option, j] += 1  # type: ignore
        change_freq[option] = change_freq[option] / len(masks)  # type: ignore

    is_changed = change_freq > 0.5
    return {
        opt: [factor for j, factor in enumerate(factors) if is_changed[opt, j]]
        for opt in range(n_options)
    }


def _create_factor_values_from_key(factor_key: Tuple[Tuple[int, int], ...],
                                   factors: List[Factor]) -> List[FactorValue]:
    """Turn a state key into ``FactorValue`` objects, one per key position.

    Position ``pos`` of ``factor_key`` is paired with ``factors[pos]``.
    """
    return [FactorValue(value_id, refinement, factors[pos])
            for pos, (value_id, refinement) in enumerate(factor_key)]


def _get_states_from_ndarray(samples: NDArray[np.object_]) -> \
        Tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]:
    s = samples[:, 0].tolist()
    s_prime = samples[:, 3].tolist()
    if len(s) > 0:
        s = np.stack(s)
    if len(s_prime) > 0:
        s_prime = np.stack(s_prime)
    reward = samples[:, 2].astype(float)
    return s, s_prime, reward
