"""Normalize datasets of trajectories."""
from dataclasses import dataclass
from swmpo.transition import Transition
import statistics
import torch


@dataclass
class VectorStats:
    """Per-feature summary statistics of a collection of vectors.

    As produced by ``get_vector_stats``, every list holds one entry per
    feature, in feature order.
    """

    # Arithmetic mean of each feature.
    means: list[float]
    # Largest observed value of each feature.
    maxs: list[float]
    # Smallest observed value of each feature.
    mins: list[float]
    # Sample standard deviation of each feature (statistics.stdev).
    stdevs: list[float]


def get_vector_stats(vectors: list[list[float]]) -> VectorStats:
    """Compute per-feature mean, max, min and sample standard deviation.

    Args:
        vectors: equally long feature vectors, one per sample.

    Returns:
        ``VectorStats`` whose lists hold one entry per feature.

    Raises:
        ValueError: if fewer than two vectors are given (the sample standard
            deviation is undefined) or if the vectors differ in length.
    """
    if len(vectors) < 2:
        raise ValueError(
            "need at least two vectors to compute statistics, "
            f"got {len(vectors)}"
        )
    width = len(vectors[0])
    if any(len(vec) != width for vec in vectors):
        raise ValueError("all vectors must have the same length")

    # Transpose in pure Python: routing the data through torch.tensor()
    # would round every value to float32 before the statistics are taken.
    feature_vals = [list(column) for column in zip(*vectors)]

    return VectorStats(
        means=[statistics.mean(vals) for vals in feature_vals],
        maxs=[max(vals) for vals in feature_vals],
        mins=[min(vals) for vals in feature_vals],
        stdevs=[statistics.stdev(vals) for vals in feature_vals],
    )


def get_normalized_vector(
    vector: torch.Tensor,
    stats: VectorStats,
) -> torch.Tensor:
    """Standardize ``vector`` feature-wise as ``(x - mean) / stdev``.

    Features whose recorded standard deviation is zero (constant in the
    dataset) are only centered rather than divided by zero, which would
    otherwise yield ``inf``/``nan``.

    Args:
        vector: tensor whose last dimension indexes features; it must have
            exactly ``len(stats.means)`` entries along that dimension. Any
            leading dimensions are broadcast over.
        stats: per-feature statistics, e.g. from ``get_vector_stats``.

    Returns:
        A new tensor with the same shape and device as ``vector``.
        Floating-point inputs keep their dtype; other inputs produce
        torch's default floating dtype.

    Raises:
        ValueError: if the feature count does not match ``stats``.
    """
    num_features = len(stats.means)
    if vector.dim() == 0 or vector.shape[-1] != num_features:
        raise ValueError(
            f"expected a vector with {num_features} features, "
            f"got shape {tuple(vector.shape)}"
        )

    # Build the statistics in the input's float dtype and on its device so
    # they neither upcast the input nor trigger a cross-device error.
    dtype = (
        vector.dtype
        if vector.is_floating_point()
        else torch.get_default_dtype()
    )
    means = torch.tensor(stats.means, dtype=dtype, device=vector.device)
    # A zero stdev means the feature was constant; divide by 1 instead.
    scales = torch.tensor(
        [stdev if stdev != 0 else 1.0 for stdev in stats.stdevs],
        dtype=dtype,
        device=vector.device,
    )
    return (vector - means) / scales


@dataclass
class TransitionStatistics:
    """Per-feature statistics of a transition dataset.

    States and actions are summarized separately because they have
    different feature spaces.
    """

    # Statistics over state vectors (used for source and next states).
    state_normalization: VectorStats
    # Statistics over action vectors.
    action_normalization: VectorStats


def get_transition_statistics(
    transitions: list[Transition],
) -> TransitionStatistics:
    """Compute per-feature state and action statistics over ``transitions``.

    Only ``source_state`` is sampled for the state statistics, on the
    rationale that a transition's ``next_state`` is the ``source_state`` of
    another transition.

    NOTE(review): that rationale does not hold for the last transition of a
    trajectory, whose ``next_state`` is never sampled. Confirm that this
    omission is acceptable.
    """
    state_rows = [t.source_state.tolist() for t in transitions]
    action_rows = [t.action.tolist() for t in transitions]
    return TransitionStatistics(
        state_normalization=get_vector_stats(state_rows),
        action_normalization=get_vector_stats(action_rows),
    )


def get_normalized_state(
    state: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Standardize ``state`` using the dataset's state statistics."""
    state_stats = stats.state_normalization
    return get_normalized_vector(state, state_stats)


def get_normalized_action(
    action: torch.Tensor,
    stats: TransitionStatistics,
) -> torch.Tensor:
    """Standardize ``action`` using the dataset's action statistics."""
    action_stats = stats.action_normalization
    return get_normalized_vector(action, action_stats)


def get_normalized_transition(
    transition: Transition,
    stats: TransitionStatistics,
) -> Transition:
    """Build a new ``Transition`` with standardized states and action.

    Both ``source_state`` and ``next_state`` are normalized with the state
    statistics; ``action`` uses the action statistics.
    """
    normalized_source = get_normalized_state(transition.source_state, stats)
    normalized_next = get_normalized_state(transition.next_state, stats)
    normalized_action = get_normalized_action(transition.action, stats)
    return Transition(
        source_state=normalized_source,
        next_state=normalized_next,
        action=normalized_action,
    )


def get_normalized_trajectories(
    trajectories: list[list[Transition]],
    stats: TransitionStatistics,
) -> list[list[Transition]]:
    """Normalize trajectory vectors using dataset statistics.

    The nesting is preserved. The result holds one list per input
    trajectory, each containing that trajectory's normalized transitions
    in their original order.
    """
    result: list[list[Transition]] = []
    for trajectory in trajectories:
        result.append(
            [get_normalized_transition(step, stats) for step in trajectory]
        )
    return result
