"""
Transition model training and utilities.

This module contains functions for training transition models using both
tabular (MLE) and neural network (MLP) approaches.
"""

import copy
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

EPS = np.finfo(np.float32).eps


def train_transition_model_wrapper(trajectories, env, model="MLE", n_epochs=5, verbose=[]):
    """
    Dispatch transition-model training to the requested method.

    Args:
        trajectories: List of trajectories, forwarded unchanged to the trainer
        env: Environment, forwarded unchanged to the trainer
        model: Model type, one of "MLE", "linear_classifier", "MLP"
        n_epochs: Number of training epochs (ignored for "MLE")
        verbose: Verbosity options, forwarded to the trainer

    Returns:
        For "MLE", the result of train_transition_model_MLE; for
        "linear_classifier" and "MLP", the result of train_transition_model_MLP
        (called with use_mlp=False and use_mlp=True respectively).

    Raises:
        ValueError: If `model` is not one of the supported names.
    """
    # The tabular estimator has no epochs, so it takes a shorter argument list.
    if model == "MLE":
        return train_transition_model_MLE(trajectories, env, verbose)

    # Both neural variants share one trainer and differ only in `use_mlp`.
    if model == "linear_classifier":
        wants_mlp = False
    elif model == "MLP":
        wants_mlp = True
    else:
        raise ValueError(f"Unknown transition model type: {model}")

    return train_transition_model_MLP(
        trajectories, env, n_epochs, use_mlp=wants_mlp, verbose=verbose
    )


def train_transition_model_MLE(trajectories, env, verbose):
    """
    Estimate a tabular transition model by counting observed transitions.

    Args:
        trajectories: List of trajectories, each laid out as
            [s0, a0, r0, s1, a1, r1, ...]
        env: Environment exposing N_states and N_actions; it is deep-copied,
            not modified
        verbose: Verbosity options; "full" prints the estimated matrix

    Returns:
        tuple:
         env_MLE: Deep copy of `env` whose `transitions` is the estimate
         transitions: Array [N_actions, N_states, N_states] of p(s'|s,a)
         transition_counts: Array of the same shape holding raw counts
    """
    n_states, n_actions = env.N_states, env.N_actions
    table_shape = (n_actions, n_states, n_states)

    # Tally every (s, a, s') triple found in the trajectories.
    transition_counts = np.zeros(table_shape)
    for traj in trajectories:
        # Each step occupies 3 slots (s, a, r); the successor state is at +3.
        for offset in range(0, len(traj) - 3, 3):
            s = traj[offset]
            a = traj[offset + 1]
            s_next = traj[offset + 3]

            in_bounds = 0 <= s < n_states and 0 <= a < n_actions and 0 <= s_next < n_states
            if not in_bounds:
                # Out-of-range entries are ignored rather than raising.
                continue
            transition_counts[a, s, s_next] += 1

    # Normalize each (a, s) row of counts into a probability distribution.
    transitions = np.zeros(table_shape)
    for a, s in np.ndindex(n_actions, n_states):
        observed = transition_counts[a, s, :]
        n_visits = np.sum(observed)
        if n_visits > 0:
            transitions[a, s, :] = observed / n_visits
        else:
            # Never-visited state-action pairs fall back to a uniform row.
            transitions[a, s, :] = 1.0 / n_states

    # Leave the caller's environment untouched; hand back an updated copy.
    env_MLE = copy.deepcopy(env)
    env_MLE.transitions = transitions
    if "full" in verbose:
        print(f"MLE transitions:\n{transitions}")

    return env_MLE, transitions, transition_counts


def train_transition_model_MLP(trajectories, env, n_epochs, use_mlp=False, verbose=[]):
    """
    Fit a neural transition model (MLP or linear) to trajectory data.

    Args:
        trajectories: List of trajectories used to build the training dataset
        env: Environment exposing N_states and N_actions; it is deep-copied,
            not modified
        n_epochs: Number of training epochs
        use_mlp: Whether to use MLP (True) or linear model (False)
        verbose: List of verbosity options; "full" prints progress

    Returns:
        tuple:
          env_trained: Deep copy of `env` whose `transitions` is the matrix
            extracted from the trained model
          transition_model: Trained transition model
          losses: Losses per epoch
    """
    chatty = "full" in verbose
    if chatty:
        print(f"Training {'MLP' if use_mlp else 'linear'} transition model")

    n_states, n_actions = env.N_states, env.N_actions

    # Batch size is capped at 4 and tied to the number of trajectories.
    loader = DataLoader(
        TrajectoriesDataset(trajectories, n_states, n_actions, verbose),
        batch_size=min(4, len(trajectories)),
        shuffle=True,
    )

    net = MLPTransitionModel(n_states, n_actions, use_mlp=use_mlp)
    adam = torch.optim.Adam(net.parameters(), lr=0.01)
    transition_model, losses = train_model(
        net, loader, adam, loss_fn_T_trajs, n_epochs, verbose
    )

    # Store the learned dynamics on a copy so the input env stays unchanged.
    env_trained = copy.deepcopy(env)
    env_trained.transitions = transition_model.extract_transition_matrix()

    if chatty:
        print(f"Transition model: {transition_model}")

    return env_trained, transition_model, losses


def sanity_check_transitions(env, fix=False, verbose=[]):
    """
    Verify that every transition row sums to 1, optionally renormalizing.

    Args:
        env: Environment with `transitions` (and N_actions / N_states, which
            are read when fixing)
        fix: If True, renormalize `env.transitions` in place when a bad row
            is found; if False, raise instead
        verbose: Verbosity options; "full" prints the offending rows and the
            outcome of the fix

    Returns:
        The same `env` object (modified in place when a fix was applied)

    Raises:
        ValueError: If some row does not sum to 1 and `fix` is False.
    """
    if env.transitions is None:
        return env

    # using stricter tolerances that are closer to what np.random.choice expects
    tolerance = {"rtol": 1e-10, "atol": 1e-10}
    row_sums = np.sum(env.transitions, axis=2)
    if np.allclose(row_sums, 1, **tolerance):
        return env

    report = "full" in verbose
    if report:
        print("\nRows that don't sum to 1 detected.")
        off_actions, off_states = np.where(~np.isclose(row_sums, 1, **tolerance))
        for a, s in zip(off_actions, off_states):
            print(f"Action {a}, State {s}: sum = {row_sums[a, s]:.10f}")
            print(f"Row values: {env.transitions[a, s]}")

    if not fix:
        raise ValueError("Transition model rows must sum to 1")

    # Renormalize every row (not only the flagged ones).
    for a, s in np.ndindex(env.N_actions, env.N_states):
        total = np.sum(env.transitions[a, s])
        if total > 0:
            env.transitions[a, s] /= total
        else:
            # Rows without positive mass are replaced by a uniform row.
            env.transitions[a, s] = 1.0 / env.N_states

    if report:  # verify fix worked
        new_row_sums = np.sum(env.transitions, axis=2)
        print(
            f"Fixing transition model: all rows now sum to 1: {np.allclose(new_row_sums, 1, rtol=1e-10, atol=1e-10)}\nFixed transitions:\n{env.transitions}"
        )

    return env


class MLPTransitionModel(nn.Module):
    """
    Transition model mapping a one-hot (state, action) pair to next-state logits.

    The network is either a 2-layer MLP with ReLU or a single bias-free linear
    layer. Its raw output is unnormalized; a softmax over the last dimension
    turns it into p(s'|s,a) (see extract_transition_matrix).
    """

    def __init__(self, N_states, N_actions, use_mlp=False, hidden_dim=32):
        """
        Initialize MLP transition model.

        Args:
            N_states: Number of states
            N_actions: Number of actions
            use_mlp: Whether to use MLP (True) or linear model (False)
            hidden_dim: Hidden dimension for MLP (unused by the linear model)
        """
        super().__init__()
        self.N_states = N_states
        self.N_actions = N_actions

        if use_mlp:  # 2-layer MLP with ReLU activation
            self.model = nn.Sequential(
                nn.Linear(N_states + N_actions, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, N_states),
            )
        else:  # linear model
            self.model = nn.Linear(N_states + N_actions, N_states, bias=False)
            torch.nn.init.kaiming_uniform_(self.model.weight)

    def get_device(self):
        """Get device of model parameters."""
        return next(self.parameters()).device

    def forward(self, state, action):
        """
        Compute next-state logits for the given state-action pairs.

        Args:
            state, action: one-hot encoded tensors whose last dimensions are
                N_states and N_actions; they are concatenated along dim -1

        Returns:
            Unnormalized logits with last dimension N_states. Apply a softmax
            over that dimension to obtain p(s'|s,a).
        """
        return self.model(torch.cat((state, action), dim=-1))

    def extract_transition_matrix(self):
        """
        Evaluate the model on every (action, state) pair.

        Returns:
            NumPy float64 array of shape [N_actions, N_states, N_states] whose
            entry [a, s, :] is softmax(forward(one_hot(s), one_hot(a))).
        """
        # Build inputs on the parameters' device/dtype so this also works
        # after the model has been moved (e.g. model.to("cuda")).
        ref = next(self.parameters())
        eye_kwargs = {"device": ref.device, "dtype": ref.dtype}

        # Row a * N_states + s of the batch encodes the pair (action a, state s).
        states = torch.eye(self.N_states, **eye_kwargs).repeat(self.N_actions, 1)
        actions = torch.eye(self.N_actions, **eye_kwargs).repeat_interleave(
            self.N_states, dim=0
        )

        # Pure evaluation: no autograd graph is needed.
        with torch.no_grad():
            probs = torch.softmax(self.forward(states, actions), dim=-1)

        transition_matrix = probs.reshape(self.N_actions, self.N_states, self.N_states)
        # .cpu() is required before .numpy() for tensors living on an accelerator.
        return transition_matrix.cpu().numpy().astype(np.double)


class TabularTransitionModel(nn.Module):
    """
    Tabular transition model for discrete MDPs.

    The learnable parameter `transition_matrix` has shape
    [N_actions, N_states, N_states] and stores log-probabilities
    log p(s'|s,a). The logprob of a trajectory is obtained by multiplying
    its (s,a,s') triplet counts with this log transition matrix.
    """

    def __init__(self, N_states, N_actions, init_matrix=None):
        """
        Args:
            N_states: Number of states
            N_actions: Number of actions
            init_matrix: Optional NumPy array of transition probabilities with
                shape [N_actions, N_states, N_states]. It is copied, not
                modified. If None, the parameter is initialized from
                torch.rand values that are then normalized.

        Raises:
            ValueError: If `init_matrix` does not have the expected shape.
        """
        super().__init__()
        self.N_states = N_states
        self.N_actions = N_actions
        if init_matrix is None:
            init_matrix = torch.rand(N_actions, N_states, N_states)
        else:
            expected_shape = (N_actions, N_states, N_states)
            if init_matrix.shape != expected_shape:
                raise ValueError(
                    f"init_matrix must have shape {expected_shape}, got {init_matrix.shape}"
                )
            probs = torch.from_numpy(copy.deepcopy(init_matrix))
            # The parameter lives in log space, so probabilities must be
            # converted with log; EPS keeps log() finite for zero entries.
            init_matrix = torch.log(probs + EPS)
        self.transition_matrix = nn.Parameter(init_matrix)
        self.normalize()

    def forward(self, triple_count):
        """
        Compute trajectory log-probabilities from triplet counts.

        Args:
            triple_count: Tensor whose last dimension has
                N_actions * N_states * N_states entries, ordered like the
                flattened `transition_matrix`

        Returns:
            Matrix product of `triple_count` with the flattened
            log transition matrix (a column vector).
        """
        return torch.matmul(triple_count, self.transition_matrix.reshape(-1, 1))

    def normalize(
        self,
    ):
        """Renormalize in place so exp(transition_matrix) sums to 1 over s'."""
        with torch.no_grad():
            self.transition_matrix -= torch.logsumexp(self.transition_matrix, -1, keepdim=True)


class TrajectoriesDataset(Dataset):
    """
    Dataset for trajectory data. Individual datapoint is (s,a,r,s').

    Each trajectory is laid out as [s0, a0, r0, s1, a1, r1, ...]. All
    transitions of all trajectories are flattened into one dataset, so
    trajectories may have different lengths. States, actions and next states
    are stored one-hot encoded as float tensors; rewards as a float tensor.
    """

    def __init__(self, trajs, N_states=5, N_actions=2, verbose=()):
        """
        Args:
            trajs: List of trajectories (possibly of different lengths)
            N_states: Number of states (one-hot width for s and s')
            N_actions: Number of actions (one-hot width for a)
            verbose: Verbosity options; "full" prints the input trajectories

        Raises:
            ValueError: If a state or action entry is not integer-valued.
        """
        if "full" in verbose:
            print(f"Creating dataset from {len(trajs)} trajectories. Trajs are:\n{trajs}")

        # Collect per trajectory instead of stacking into a rectangular
        # tensor, so ragged and empty inputs are handled.
        states, actions, rewards, next_states = [], [], [], []
        for traj in trajs:
            states.extend(traj[:-3:3])
            actions.extend(traj[1:-2:3])
            rewards.extend(traj[2:-1:3])
            next_states.extend(traj[3::3])

        states = self._as_index_tensor(states, "states")
        actions = self._as_index_tensor(actions, "actions")
        next_states = self._as_index_tensor(next_states, "next states")
        self.rewards = torch.tensor(rewards, dtype=torch.float32)

        # One-hot encode:
        self.states = nn.functional.one_hot(states, num_classes=N_states).float()
        self.actions = nn.functional.one_hot(actions, num_classes=N_actions).float()
        self.next_states = nn.functional.one_hot(next_states, num_classes=N_states).float()

    @staticmethod
    def _as_index_tensor(values, what):
        """Convert a flat list of integer-valued numbers to an int64 tensor."""
        # Trajectories mix ints with float rewards, so indices may arrive as
        # floats (e.g. from a float array); accept them only if they are whole.
        raw = torch.tensor(values, dtype=torch.float64)
        indices = raw.long()
        if not torch.equal(indices.double(), raw):
            raise ValueError(f"Trajectory {what} must be integer-valued indices")
        return indices

    def __len__(self):
        """Return the total number of transitions across all trajectories."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return the (state, action, reward, next_state) tensors at `idx`."""
        return (
            self.states[idx],
            self.actions[idx],
            self.rewards[idx],
            self.next_states[idx],
        )


def loss_fn_T_trajs(batch, transition_model):
    """Cross-entropy loss for predicting the successor state:
    L = CE(T(.|s,a), s')

    Args:
        batch: 4-tuple (state, action, reward, next_state); the reward entry
            is ignored. The caller in this file passes one-hot tensors.
        transition_model: Callable model returning next-state logits and
            exposing get_device()

    Returns:
        Scalar loss tensor (default mean reduction over the batch).

    Note: s' (next_state) is passed as a one-hot vector and therefore treated
    as a 'class probabilities' target.
    TODO: the PyTorch docs say the cross-entropy computation is faster when the
    target is a class index, so next_state could be reduced to that index.
    """
    state, action, _reward, next_state = batch  # all one-hot encoded.

    # Inputs and target must live on the same device as the model parameters.
    target_device = transition_model.get_device()
    state, action, next_state = (
        tensor.to(target_device) for tensor in (state, action, next_state)
    )

    logits = transition_model(state, action)  # size (batch_size, N_states)
    return nn.functional.cross_entropy(logits, next_state)


def train_model(model, dataloader, optimizer, loss_fn, n_epochs=10, verbose=[]):
    """
    Run a plain gradient-descent training loop.

    Args:
        model: Model to train (updated in place through `optimizer`)
        dataloader: Iterable yielding batches
        optimizer: Optimizer over the model's parameters
        loss_fn: Callable `loss_fn(batch, model)` returning a scalar loss
        n_epochs: Number of passes over `dataloader`
        verbose: Verbosity options; "full" prints the loss after each epoch

    Returns:
        tuple: (trained_model, loss_per_epoch), where loss_per_epoch holds the
        mean batch loss of each epoch
    """

    def _optimize_on(batch):
        # One optimizer update; returns the batch loss as a Python float.
        optimizer.zero_grad()
        batch_loss = loss_fn(batch, model)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    log_epochs = "full" in verbose
    loss_per_epoch = []

    for epoch in range(n_epochs):
        epoch_loss = np.mean([_optimize_on(batch) for batch in dataloader])
        loss_per_epoch.append(epoch_loss)

        if log_epochs:
            print(f"Epoch {epoch}, Loss: {epoch_loss:.4f}")

    return model, loss_per_epoch
