"""Partition of a dataset of state transitions guided by synthesis of
local neural models."""
from dataclasses import dataclass
from swmpo.transition import Transition
from swmpo.transition import get_vector
from swmpo.model import get_input_output_size
from swmpo.transition_normalization import VectorStats
from swmpo.transition_normalization import get_vector_stats
from swmpo.transition_normalization import get_normalized_vector
from swmpo.transition_prunning.island_prunning import prune_short_transitions
from swmpo.model import get_relu_mlp
from sklearn.preprocessing import StandardScaler
from functools import cached_property
from itertools import product
import sklearn.cluster
import random
import torch
from torch.distributions.multivariate_normal import MultivariateNormal
import umap

# Restrict PyTorch to a single intra-op thread so that this script does not
# take over the computer's resources. You can remove this line if you are
# not running on a shared lab computer.
torch.set_num_threads(1)


@dataclass
class StatePartitionItem:
    """One subset of a partition of a dataset of transitions, together with
    the local neural model fitted to that subset."""
    # Local dynamics model associated with `subset`.
    local_model: torch.nn.Module
    # Transitions that belong to this partition item.
    subset: list[Transition]
    # Hidden layer sizes used when building `local_model`.
    hidden_sizes: list[int]

    @cached_property
    def transition_vectors_as_set_of_tuples(self) -> set[tuple[float, ...]]:
        """Vectors of the subset's transitions, as hashable tuples.

        The set allows O(1) membership tests. It is computed on first access
        and cached, so later mutations of `subset` are not reflected.
        """
        return {
            tuple(get_vector(member).tolist())
            for member in self.subset
        }


class PartitionSortingError(Exception):
    """Signals that a partition could not be ordered so that its first item
    contains the episodes' initial transitions."""


def item_contains_transition(
        item: StatePartitionItem,
        transition: Transition,
        ) -> bool:
    """Check whether `transition` belongs to the subset of `item`.

    Membership is decided by the transition's vector (as produced by
    `get_vector`), so two transitions with identical vectors are treated as
    the same transition.
    """
    lookup_key = tuple(get_vector(transition).tolist())
    known_vectors = item.transition_vectors_as_set_of_tuples
    return lookup_key in known_vectors


def get_initial_transition_n(
        item: StatePartitionItem,
        episodes: list[list[Transition]],
        ) -> int:
    """Count how many episodes start with a transition contained in `item`.

    Args:
        item: Partition item whose subset is searched.
        episodes: Episodes of transitions. Empty episodes are skipped.

    Returns:
        The number of non-empty episodes whose first transition appears in
        `item`, with membership as defined by `item_contains_transition`.
    """
    return sum(
        1
        for episode in episodes
        if len(episode) > 0 and item_contains_transition(
            item=item,
            transition=episode[0],
        )
    )


def get_sorted_partition(
        partition: list[StatePartitionItem],
        episodes: list[list[Transition]],
        ) -> list[StatePartitionItem]:
    """Order partition items by how many episodes start inside them.

    Items are ranked by `get_initial_transition_n` in descending order, so
    the item holding the most initial transitions comes first. Items with
    equal counts appear in the reverse of their input order.
    """
    def initial_count(candidate):
        return get_initial_transition_n(item=candidate, episodes=episodes)

    ascending = sorted(partition, key=initial_count)
    return ascending[::-1]


def get_partition_modes(
    trajectory: list[Transition],
    partition: list[StatePartitionItem],
) -> list[int]:
    """Map each transition of `trajectory` to the index of its partition item.

    If several items contain the same transition, the highest index is used.
    An AssertionError is raised (when assertions are enabled) if a transition
    belongs to no item.
    """
    def owner_index(transition):
        owners = [
            position
            for position, candidate in enumerate(partition)
            if item_contains_transition(candidate, transition)
        ]
        return owners[-1] if owners else None

    result: list[int] = []
    for transition in trajectory:
        owner = owner_index(transition)
        assert owner is not None, "Partition doesn't contain transition!"
        result.append(owner)
    return result


def get_optimized_model(
        transitions: list[Transition],
        hidden_sizes: list[int],
        learning_rate: float,
        iter_n: int,
        dt: float,
        seed: str,
        batch_size: int,
        device: str,
        verbose: bool,
        ) -> torch.nn.Module:
    """Fit a local dynamics model to a set of transitions.

    The model (built by `get_relu_mlp`) maps ``[source_state, action]`` to a
    state derivative ``f``. It is trained so that
    ``source_state + f * dt`` matches ``next_state``; that is, the data is
    assumed to come from an Euler-integrated simulation. Adam minimizes the
    mean Euclidean prediction error over mini-batches sampled uniformly
    without replacement.

    Args:
        transitions: Non-empty list of transitions to fit.
        hidden_sizes: Hidden layer sizes of the model.
        learning_rate: Adam learning rate.
        iter_n: Number of gradient steps.
        dt: Integration time step.
        seed: Seed for model initialization and batch sampling.
        batch_size: Maximum number of transitions per gradient step.
        device: Torch device used during training.
        verbose: If True, print the loss at every iteration.

    Returns:
        The trained model, in eval mode and moved to the CPU.

    Raises:
        ValueError: If `transitions` is empty.
    """
    if len(transitions) == 0:
        raise ValueError(
            "Cannot fit a local model to an empty set of transitions."
        )
    _random = random.Random(seed)

    # Find size of the inputs and outputs
    input_size, output_size = get_input_output_size(transitions[0])

    # Initialize local model
    model = get_relu_mlp(
        input_size=input_size,
        hidden_sizes=hidden_sizes,
        output_size=output_size,
        seed=str(_random.random()),
    ).to(device).train(True)

    # Initialize optimization algorithm
    optimizer = torch.optim.Adam(
        params=list(model.parameters()),
        lr=learning_rate,
    )
    optimizer.zero_grad()

    # Build regression inputs [s_t, a_t] and targets s_{t+1}
    X = torch.stack([
        torch.cat([transition.source_state, transition.action])
        for transition in transitions
    ]).clone().detach()
    Y_target = torch.stack([
        transition.next_state
        for transition in transitions
    ]).clone().detach()
    state_size = len(transitions[0].source_state)
    X_state = X[:, :state_size]

    # Run gradient descent
    indices = list(range(len(X)))
    for iter_i in range(iter_n):

        # Sample the batch uniformly without replacement (the same scheme
        # the mode encoder uses), so every transition is equally likely to
        # be visited regardless of its position in `transitions`.
        batch_is = _random.sample(indices, k=min(batch_size, len(indices)))

        X_batch = X[batch_is].to(device=device)
        X_state_batch = X_state[batch_is].to(device=device)
        Y_target_batch = Y_target[batch_is].to(device=device)

        # Step model. We are assuming data comes from an Euler-integrated
        # simulation: x_{t+1} = x_t + f(x)*dt
        # We want to approximate f(x)
        X_dot = model(X_batch)
        Y_predicted = X_state_batch + X_dot*dt

        # Mean Euclidean error over the batch
        loss = (Y_predicted - Y_target_batch).norm(dim=1).mean()

        # Step optimization algorithm
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if verbose:
            print(f"[get_optimized_model] Iter: {iter_i}/{iter_n}; Loss: {loss.item()}")

    return model.eval().cpu()


@dataclass
class ModeWorldModel:
    """Encoder/decoder pair produced by `get_predictive_residual_encoder`,
    together with the statistics used to normalize its inputs."""
    # Maps normalized (previous state, previous action, state) to a latent
    # mode vector.
    encoder: torch.nn.Module
    # Maps (latent mode, normalized state, normalized action) to a state
    # derivative that is Euler-integrated with dt.
    decoder: torch.nn.Module
    # Statistics used to normalize states and next states.
    state_normalization: VectorStats
    # Statistics used to normalize actions.
    action_normalization: VectorStats
    # Training loss recorded at each optimization iteration.
    loss_log: list[float]


def get_mode_vector(
    transition: Transition,
    mode_world_model: ModeWorldModel,
    device: str,
) -> list[float]:
    """Get the embedded (latent mode) representation of a transition.

    The transition's source state, action and next state are normalized with
    the statistics stored in `mode_world_model`, concatenated, and passed
    through the model's encoder as a batch of one.

    Returns:
        The latent vector as a flat list of floats. This is always a list,
        even when the latent size is 1.
    """
    sast = torch.cat([
        get_normalized_vector(
            transition.source_state,
            mode_world_model.state_normalization,
        ),
        get_normalized_vector(
            transition.action,
            mode_world_model.action_normalization,
        ),
        get_normalized_vector(
            transition.next_state,
            mode_world_model.state_normalization,
        ),
    ]).to(device)
    with torch.no_grad():
        # Drop only the batch dimension. A bare `squeeze()` would also
        # collapse a size-1 latent into a 0-d tensor, whose `tolist()` is a
        # float rather than a list.
        # NOTE(review): assumes the encoder returns shape (1, latent_size)
        # for a single input — confirm for `get_relu_mlp`.
        encoded_vector = mode_world_model.encoder(sast.unsqueeze(0)).squeeze(0)
    return encoded_vector.cpu().tolist()


def get_clusters(
    mode_world_model: ModeWorldModel,
    trajectories: list[list[Transition]],
    cluster_n: int,
    min_island_size: int,
    dimensionality_reduce: int | None,
    device: str,
    seed: int | None = None,
) -> list[set[tuple[int, int]]]:
    """Partition the given dataset of transitions into disjoint subsets.

    Each transition is embedded with `get_mode_vector`. The embeddings are
    standardized, optionally reduced with UMAP, and clustered with KMeans
    into `cluster_n` clusters. Within each trajectory, short runs of labels
    are then cleaned up with `prune_short_transitions`.

    Args:
        mode_world_model: Model whose encoder embeds transitions.
        trajectories: Non-empty trajectories of transitions.
        cluster_n: Number of KMeans clusters.
        min_island_size: Passed to `prune_short_transitions`.
        dimensionality_reduce: Number of UMAP components, or None to skip
            dimensionality reduction.
        device: Torch device used for encoding.
        seed: Optional `random_state` for UMAP and KMeans. None (the default)
            keeps their unseeded, non-deterministic behavior.

    Returns:
        The non-empty clusters. Each is a set of ``(trajectory_index,
        transition_index)`` locations.
    """
    # Validate up front, before the expensive encoding and clustering.
    for trajectory in trajectories:
        assert len(trajectory) > 0

    # Bookkeeping between flat vector indices and (trajectory, step) locations
    vector_indices = [
        (i, j)
        for i, trajectory in enumerate(trajectories)
        for j in range(len(trajectory))
    ]
    location_index = {
        location: index
        for index, location in enumerate(vector_indices)
    }

    # Get the latent vector for each transition (same flat order as above)
    X = torch.tensor([
        get_mode_vector(
            transition,
            mode_world_model=mode_world_model,
            device=device,
        )
        for trajectory in trajectories
        for transition in trajectory
    ])

    # Normalize embeddings
    X = StandardScaler().fit_transform(X)
    if dimensionality_reduce is not None:
        reducer = umap.UMAP(
            n_components=dimensionality_reduce,
            random_state=seed,
        )
        X = reducer.fit_transform(X)

    # Cluster latent vectors
    labels = sklearn.cluster.KMeans(
        n_clusters=cluster_n,
        random_state=seed,
    ).fit_predict(X)

    # Prune short runs ("islands") of modes within each trajectory
    new_labels = list(labels)
    for i, trajectory in enumerate(trajectories):
        modes = [
            labels[location_index[(i, j)]]
            for j in range(len(trajectory))
        ]
        new_modes = prune_short_transitions(modes, min_island_size)
        for j, new_mode in enumerate(new_modes):
            new_labels[location_index[(i, j)]] = new_mode

    # Assemble clusters
    clusters = [set[tuple[int, int]]() for _ in range(cluster_n)]
    for location, cluster_i in zip(vector_indices, new_labels):
        clusters[cluster_i].add(location)

    # Remove empty clusters
    return [cluster for cluster in clusters if len(cluster) > 0]


def get_mean_information_content(
    X: torch.Tensor,  # shape (vector_n, feature_n)
    device: str,
) -> torch.Tensor:
    """Return the mean information content of the given set of vectors
    with respect to a multi-variate Gaussian distribution centered at 0
    with unit variance.

    We assume features are independent, so the covariance is the identity.

    Args:
        X: Tensor of shape (vector_n, feature_n), on `device`.
        device: Device on which the reference distribution is built.

    Returns:
        A scalar tensor: the average of ``-log p(x)`` over the rows of X.
    """
    _, feature_n = X.size()
    loc = torch.zeros((feature_n,), device=device)
    scale = torch.ones((feature_n,), device=device)
    # Because we assume the variables are independent, we use a diagonal
    # covariance matrix
    covariance_matrix = torch.diag(scale)
    n = MultivariateNormal(
        loc=loc,
        covariance_matrix=covariance_matrix,
    )
    log_probs = n.log_prob(X)
    # Average over vectors. `.sum().mean()` would return the total instead,
    # making the regularizer grow with the batch size.
    return (-log_probs).mean()


def get_distribution(X: torch.Tensor) -> MultivariateNormal:
    """Fit a diagonal-covariance Gaussian to the given vectors.

    X is assumed to be of shape (sample_n, feature_n). Features are treated
    as independent. The mean is the per-feature sample mean, and the
    covariance is a diagonal matrix of per-feature sample variances
    (unbiased estimator, as computed by `torch.var_mean`).

    Returns:
        A `MultivariateNormal` with `loc` of shape (feature_n,) and a
        (feature_n, feature_n) diagonal covariance matrix.

    NOTE(review): a feature with zero sample variance (or a single sample)
    gives a singular/NaN covariance that `MultivariateNormal` will reject.
    """
    _, feature_n = X.size()
    # Statistics are computed per feature, i.e. across the sample dimension.
    feature_variances, feature_means = torch.var_mean(X, dim=0)

    # Because we assume the variables are independent, we use a diagonal
    # covariance matrix. Its entries are variances (stdev**2), not stdevs.
    covariance_matrix = torch.diag(feature_variances)

    assert tuple(feature_means.size()) == (feature_n,)
    assert tuple(covariance_matrix.size()) == (feature_n, feature_n)
    return MultivariateNormal(
        loc=feature_means,
        covariance_matrix=covariance_matrix,
    )


def get_mutual_information(
    X: torch.Tensor,
    Y: torch.Tensor,
    mini_batch_size: int,
    seed: str,
) -> torch.Tensor:
    """Approximate the mutual information between X and Y.

    This assumes that the distributions P_X, P_Y and P_(X, Y) are
    Gaussians (fitted with `get_distribution`).

    X and Y are assumed to be of shape (sample_n, features1_n) and
    (sample_n, features2_n) respectively.

    At most `mini_batch_size` (i, j) index pairs are sampled without
    replacement from range(len(X)) x range(len(Y)), and the integrand is
    averaged over them.

    Raises:
        ValueError: If no pair can be sampled (empty inputs or
            `mini_batch_size` of 0).
    """
    # Characterize P(X, Y), P(X) and P(Y)
    PX = get_distribution(X)
    PY = get_distribution(Y)
    XY = torch.cat([X, Y], dim=1)
    PXY = get_distribution(XY)

    # Formula is
    # I(X, Y) = \int x in X \int y in Y P(x, y) * log(P(x, y)/(P(x)*P(y))) dxdy
    #         = \int x in X \int y in Y P(x, y) * (log(P(x, y)) - log(P(x)*P(y))) dxdy
    #         = \int x in X \int y in Y P(x, y) * (log(P(x, y)) - (log(P(x)) + log(P(y)))) dxdy
    #
    # Making \int and log(P(-)) notation more concise
    #
    #         = int x int y  P(x, y) * (logP(x, y) - (logP(x) + logP(y))) dxdy
    #         = int x int y  P(x, y) * (logP(x, y) - logP(x) - logP(y)) dxdy
    #
    # NOTE(review): pairs are drawn from the product of the empirical
    # marginals, so an unbiased estimate of this integral would weight each
    # term by P(x, y)/(P(x)*P(y)) rather than by P(x, y). As written this is
    # a heuristic proxy — confirm it is intended.

    log_pX = PX.log_prob(X)
    log_pY = PY.log_prob(Y)

    # Speed: sample a subset of the pairs over which the integral is
    # performed. The Cartesian product is indexed implicitly through a flat
    # index (i = idx // len(Y), j = idx % len(Y), i.e. the same i-major order
    # as itertools.product) instead of materializing len(X) * len(Y) tuples.
    # random.sample picks the same positions for a range as for a list of
    # equal length, so the sampled pairs are the same as enumerating them.
    _random = random.Random(seed)
    y_n = len(Y)
    pair_n = len(X) * y_n
    flat_indices = _random.sample(range(pair_n), k=min(mini_batch_size, pair_n))
    if len(flat_indices) == 0:
        raise ValueError("No (x, y) pairs available to estimate mutual information.")

    i_index = torch.tensor([idx // y_n for idx in flat_indices], device=X.device)
    j_index = torch.tensor([idx % y_n for idx in flat_indices], device=Y.device)

    Slog_pX = log_pX[i_index]
    Slog_pY = log_pY[j_index]
    log_pXY = PXY.log_prob(torch.cat([X[i_index], Y[j_index]], dim=1))
    pXY = log_pXY.exp()

    assert Slog_pX.size() == Slog_pY.size() == log_pXY.size()

    IXY = pXY * (log_pXY - Slog_pX - Slog_pY)
    return torch.mean(IXY)


def get_predictive_residual_encoder(
    trajectories: list[list[Transition]],
    hidden_sizes: list[int],
    latent_size: int,
    learning_rate: float,
    iter_n: int,
    dt: float,
    seed: str,
    batch_size: int,
    information_content_regularization_scale: float,
    mutual_information_regularization_scale: float,
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> ModeWorldModel:
    """Return a mode_world_model optimized on the given dataset of
    transitions.

    For every step t >= 1 of every trajectory, the encoder maps the
    normalized triple (s_{t-1}, a_{t-1}, s_t) to a latent mode m_t. The
    decoder maps (m_t, s_t, a_t) to a state derivative that is
    Euler-integrated to predict s_{t+1} = s_t + decoder(...) * dt. The loss
    is the mean Euclidean prediction error (in normalized units) plus a
    scaled information-content term and a scaled mutual-information term.

    Args:
        trajectories: Trajectories of transitions; the first transition of
            the first trajectory determines the state/action sizes.
        hidden_sizes: Hidden layer sizes of both encoder and decoder.
        latent_size: Dimension of the latent mode vector.
        learning_rate: Adam learning rate.
        iter_n: Number of gradient steps.
        dt: Integration time step.
        seed: Seed for model initialization and batch sampling.
        batch_size: Maximum number of samples per gradient step.
        information_content_regularization_scale: Weight of the
            information-content term.
        mutual_information_regularization_scale: Weight of the
            mutual-information term.
        mutual_information_mini_batch_size: Number of pairs sampled when
            estimating mutual information.
        device: Torch device used during training.
        verbose: If True, print the loss at every iteration.

    Returns:
        A ModeWorldModel with the trained encoder and decoder (in eval mode,
        still on `device`), the normalization statistics, and the
        per-iteration loss log.

    Raises:
        ValueError: If no trajectory has at least two transitions, so there
            is no (t-1, t) pair to train on.
    """
    _random = random.Random(seed)

    # Find size of the states
    state_size = len(trajectories[0][0].source_state)
    action_size = len(trajectories[0][0].action)

    # Initialize local models
    encoder = get_relu_mlp(
        input_size=state_size+action_size+state_size,
        hidden_sizes=hidden_sizes,
        output_size=latent_size,
        seed=str(_random.random()),
    ).to(device).train(True)
    decoder = get_relu_mlp(
        input_size=latent_size+state_size+action_size,
        hidden_sizes=hidden_sizes,
        output_size=state_size,
        seed=str(_random.random()),
    ).to(device).train(True)

    # Initialize optimization algorithm
    parameters = list()
    parameters.extend(encoder.parameters())
    parameters.extend(decoder.parameters())
    optimizer = torch.optim.Adam(
        params=parameters,
        lr=learning_rate,
    )
    optimizer.zero_grad()

    # Build regression targets
    # tm1 is "t minus 1". E.g., stm1 = s{t-1}
    # Similarly, tp1 is "t plus 1".
    St_list = list()
    At_list = list()
    Stm1_list = list()
    Atm1_list = list()
    Stp1_list = list()
    for trajectory in trajectories:
        # Start at t = 1 so that t - 1 is a real, earlier step. Starting at
        # t = 0 would make trajectory[t - 1] wrap around to the trajectory's
        # last transition.
        for t in range(1, len(trajectory)):
            previous_transition = trajectory[t-1]
            transition = trajectory[t]

            St_list.append(transition.source_state.tolist())
            At_list.append(transition.action.tolist())
            Stm1_list.append(previous_transition.source_state.tolist())
            Atm1_list.append(previous_transition.action.tolist())
            Stp1_list.append(transition.next_state.tolist())

    if len(St_list) == 0:
        raise ValueError(
            "At least one trajectory with two or more transitions is "
            "required to train the mode world model."
        )

    state_normalization = get_vector_stats(St_list)
    action_normalization = get_vector_stats(At_list)

    def normalize_rows(rows, stats) -> torch.Tensor:
        # Normalize each row with `stats` and stack the rows into a tensor
        return torch.tensor([
            get_normalized_vector(torch.tensor(row), stats).tolist()
            for row in rows
        ])

    St = normalize_rows(St_list, state_normalization)
    At = normalize_rows(At_list, action_normalization)
    Stm1 = normalize_rows(Stm1_list, state_normalization)
    Atm1 = normalize_rows(Atm1_list, action_normalization)
    Stp1 = normalize_rows(Stp1_list, state_normalization)

    # Run gradient descent
    loss_log = list()
    indices = list(range(len(St)))
    print(f"[get_predictive_residual_mode] Dataset size: {len(indices)}")

    for iter_i in range(iter_n):

        # Sample batch
        batch_is = _random.sample(indices, k=min(batch_size, len(indices)))

        St_batch = St[batch_is].to(device=device)
        At_batch = At[batch_is].to(device=device)
        Stm1_batch = Stm1[batch_is].to(device=device)
        Atm1_batch = Atm1[batch_is].to(device=device)
        Stp1_batch = Stp1[batch_is].to(device=device)

        # Get joint variable SAS_{t} = (S_{t-1}, A_{t-1}, S_{t})
        SASt_batch = torch.cat([Stm1_batch, Atm1_batch, St_batch], dim=1)

        # Get latent mode variable M_{t}
        Mt_batch = encoder(SASt_batch)

        # Get joint variable MSA_{t} = (M_{t}, S_{t}, A_{t})
        MAt_batch = torch.cat([Mt_batch, St_batch, At_batch], dim=1)

        # Get model output
        DeltaSt_batch_predicted = decoder(MAt_batch)

        # Get model errors
        Stp1_batch_predicted = St_batch + DeltaSt_batch_predicted*dt
        errors = (Stp1_batch - Stp1_batch_predicted)

        # NOTE(review): both regularizers receive MAt (which includes the
        # data S_t and A_t), not only the latent M_t. Only the M_t part
        # carries gradient. Confirm this is intended.

        # Get information bottleneck regularization term
        information_content = information_content_regularization_scale * \
            get_mean_information_content(MAt_batch, device)

        # Get mutual information bottleneck regularization term
        mutual_information = mutual_information_regularization_scale *\
            get_mutual_information(
                MAt_batch,
                Stp1_batch,
                seed=str(_random.random()),
                mini_batch_size=mutual_information_mini_batch_size,
            )

        # Aggregate weighted errors
        loss = errors.norm(dim=1).mean()
        loss = loss + information_content
        loss = loss + mutual_information

        # Step optimization algorithm
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Log losses
        loss_log.append(loss.item())

        if verbose:
            print(f"[get_predictive_residual_encoder] Iter: {iter_i}/{iter_n}; Loss: {loss.item()}")

    mode_world_model = ModeWorldModel(
        encoder=encoder.eval(),
        decoder=decoder.eval(),
        state_normalization=state_normalization,
        action_normalization=action_normalization,
        loss_log=loss_log,
    )
    return mode_world_model


@dataclass
class OptimizationResult:
    """Output of the partition pipeline."""
    # Partition items, each with its subset of transitions and local model.
    partition: list[StatePartitionItem]
    # Per-iteration training loss of the mode world model.
    loss_log: list[float]
    # Encoder/decoder model whose latent space was used for clustering.
    mode_world_model: ModeWorldModel


def get_optimized_error_partition(
    trajectories: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    iter_n: int,
    mode_model_iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    dt: float,
    size: int,
    seed: str,
    min_island_size: int,
    batch_size: int,
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Partition the transitions of `trajectories` into dynamics modes.

    1. Train a mode world model (`get_predictive_residual_encoder`) for
       `iter_n` iterations.
    2. Cluster the transitions' latent vectors into at most `size` non-empty
       clusters (`get_clusters`).
    3. Fit one local model per cluster for `mode_model_iter_n` iterations
       (`get_optimized_model`).

    All sub-seeds are drawn, in order, from a generator seeded with `seed`.
    The returned loss log is the mode world model's training loss.
    """
    rng = random.Random(seed)

    if verbose:
        print("Normalizing trajectories for mode_world_model")

    mode_world_model = get_predictive_residual_encoder(
        trajectories=trajectories,
        hidden_sizes=hidden_sizes,
        latent_size=latent_size,
        learning_rate=learning_rate,
        iter_n=iter_n,
        dt=dt,
        seed=str(rng.random()),
        batch_size=batch_size,
        information_content_regularization_scale=clustering_information_content_regularization_scale,
        mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        device=device,
        verbose=verbose,
    )

    clusters = get_clusters(
        mode_world_model=mode_world_model,
        trajectories=trajectories,
        cluster_n=size,
        min_island_size=min_island_size,
        dimensionality_reduce=clustering_dimensionality_reduce,
        device=device,
    )

    def fit_partition_item(cluster):
        members = [trajectories[i][j] for (i, j) in cluster]
        local_model = get_optimized_model(
            transitions=members,
            hidden_sizes=hidden_sizes,
            learning_rate=learning_rate,
            iter_n=mode_model_iter_n,
            dt=dt,
            seed=str(rng.random()),
            batch_size=batch_size,
            device=device,
            verbose=verbose,
        )
        return StatePartitionItem(
            local_model=local_model,
            subset=members,
            hidden_sizes=hidden_sizes,
        )

    # Comprehension preserves cluster order, hence the order of seed draws.
    partition = [fit_partition_item(cluster) for cluster in clusters]

    return OptimizationResult(
        partition=partition,
        loss_log=mode_world_model.loss_log,
        mode_world_model=mode_world_model,
    )


def get_partition(
    episodes: list[list[Transition]],
    hidden_sizes: list[int],
    learning_rate: float,
    latent_size: int,
    optimization_iter_n: int,
    mode_model_iter_n: int,
    clustering_dimensionality_reduce: int | None,
    clustering_information_content_regularization_scale: float,
    clustering_mutual_information_regularization_scale: float,
    dt: float,
    size: int,
    min_island_size: int,
    seed: str,
    batch_size: int,
    mutual_information_mini_batch_size: int,
    device: str,
    verbose: bool,
) -> OptimizationResult:
    """Returns a partition of the set of transitions in the given episodes.
    Each subset of the partition has a corresponding neural model of the
    dynamics in that subset.

    The returned partition is sorted so that the first item is the one that
    contains the most initial transitions of the input episodes (see
    `get_sorted_partition`).

    Raises:
        PartitionSortingError: If, after sorting, the first item contains
            none of the episodes' initial transitions (for example, when the
            partition is empty).
    """
    _random = random.Random(seed)

    # Optimize smooth partition
    optimization_result = get_optimized_error_partition(
        trajectories=episodes,
        hidden_sizes=hidden_sizes,
        learning_rate=learning_rate,
        iter_n=optimization_iter_n,
        dt=dt,
        latent_size=latent_size,
        mode_model_iter_n=mode_model_iter_n,
        clustering_dimensionality_reduce=clustering_dimensionality_reduce,
        clustering_information_content_regularization_scale=clustering_information_content_regularization_scale,
        clustering_mutual_information_regularization_scale=clustering_mutual_information_regularization_scale,
        size=size,
        verbose=verbose,
        batch_size=batch_size,
        mutual_information_mini_batch_size=mutual_information_mini_batch_size,
        min_island_size=min_island_size,
        device=device,
        seed=str(_random.random()),
    )

    # Sort state partition so that the first item contains the most
    # initial transitions of the episodes
    state_partition = get_sorted_partition(
        episodes=episodes,
        partition=optimization_result.partition,
    )

    # Enforce the documented contract
    if len(state_partition) == 0 or get_initial_transition_n(
        item=state_partition[0],
        episodes=episodes,
    ) == 0:
        raise PartitionSortingError(
            "No partition item contains an initial transition of the episodes."
        )

    return OptimizationResult(
        partition=state_partition,
        loss_log=optimization_result.loss_log,
        mode_world_model=optimization_result.mode_world_model,
    )
