"""
Official implementation of the Bisimulator algorithm.
"""

from typing import Union, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import nevergrad as ng
import numpy as np
import os

from fair_gym import LendingEnv, CollegeAdmissionEnv, AcceptRejectAction

from agents.ppo import PPO
from utils.rollout_buffer import BisimulatorBuffer
from utils.env_utils import preprocess_lending_obs, preprocess_college_admission_obs


def layer_init(layer: nn.Module, std=np.sqrt(2), bias_const=0.0) -> nn.Module:
    """
    Apply orthogonal weight initialization and a constant bias to a layer.

    The layer is modified in place and also returned, so the call can be
    nested directly inside an ``nn.Sequential(...)`` definition.

    Args:
        layer (nn.Module): A module exposing ``weight`` and ``bias`` tensors.
        std (float): Scaling factor handed to ``orthogonal_`` as its gain.
        bias_const (float): Value every bias entry is set to.

    Returns:
        nn.Module: The same ``layer`` object, now initialized.
    """
    nn.init.orthogonal_(layer.weight, gain=std)
    nn.init.constant_(layer.bias, val=bias_const)
    return layer


def get_quantile_masks(states: torch.Tensor) -> list[torch.Tensor]:
    """
    Split the given values into four quartile buckets.

    The 25%, 50% and 75% quantiles are computed over all elements of
    ``states`` (linear interpolation). Bucket ``k`` contains the elements
    that are greater than the previous quantile and less than or equal to
    the current one; the last bucket contains everything above the 75%
    quantile.

    Args:
        states (torch.Tensor): Floating-point values to bucket, e.g. a
            column of state ids of shape ``(N, 1)``.

    Returns:
        list[torch.Tensor]: Four boolean masks, ordered from the lowest to
            the highest quartile. Size-1 dimensions are squeezed out, but
            each mask always keeps at least one dimension so it can be used
            to select rows even when ``states`` holds a single element.
    """
    n_buckets = 4

    # torch.quantile rejects empty inputs; an empty input simply has four
    # empty buckets.
    if states.numel() == 0:
        empty = torch.atleast_1d(
            torch.zeros_like(states, dtype=torch.bool).squeeze()
        )
        return [empty.clone() for _ in range(n_buckets)]

    # Quartile boundaries
    q1, q2, q3 = (
        torch.quantile(states, q, interpolation="linear")
        for q in (0.25, 0.5, 0.75)
    )

    masks = [
        states <= q1,
        (states > q1) & (states <= q2),
        (states > q2) & (states <= q3),
        states > q3,
    ]

    # A bare squeeze() turns a one-element mask into a 0-d tensor, and
    # indexing with a 0-d boolean adds a dimension instead of filtering
    # rows, so keep the masks at least 1-D.
    return [torch.atleast_1d(mask.squeeze()) for mask in masks]


def one_hot_encode(values: torch.Tensor, size: int) -> torch.Tensor:
    """
    One-hot encode a batch of class indices.

    Args:
        values (torch.Tensor): Class indices, either flat ``(N,)`` or as a
            column ``(N, 1)``. Values are cast to ``long`` before use.
        size (int): The number of classes (width of the encoding).

    Returns:
        torch.Tensor: A float tensor of shape ``(N, size)`` on the same
            device as ``values``, with a single 1 per row.
    """
    # Flatten first: indexing with an (N,) row index and an (N, 1) column
    # index would broadcast to (N, N) and set, on every row, a 1 for every
    # class that occurs anywhere in the batch.
    indices = values.reshape(-1).long()
    n_rows = indices.shape[0]

    one_hot = torch.zeros(n_rows, size, device=values.device)
    one_hot[torch.arange(n_rows, device=values.device), indices] = 1
    return one_hot


class Reward(nn.Module):
    """
    Models the reward function as a neural network.

    The network is a two-hidden-layer MLP over the concatenation of the
    state, the group vector and the one-hot encoded action; a final sigmoid
    keeps the predicted reward in ``(0, 1)``.
    """

    def __init__(
        self,
        state_without_group_dim: int,
        group_dim: int,
        n_actions: int,
        hidden_width: int
    ) -> None:
        """
        Initializes the Reward model.

        Args:
            state_without_group_dim (int): The state dimension without the group dimension.
            group_dim (int): The group dimension.
            n_actions (int): The number of actions.
            hidden_width (int): The hidden width of the neural network.
        """
        super().__init__()

        self.n_actions = n_actions
        self.net = nn.Sequential(
            layer_init(
                nn.Linear(
                    state_without_group_dim + group_dim + n_actions, hidden_width
                )
            ),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_width, 1), std=1.0),
            nn.Sigmoid(),
        )

    def forward(
        self, state: torch.Tensor, group: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass of the reward model.

        Args:
            state (torch.Tensor): The current state; flattened to
                ``(N, state_dim)`` using its last dimension.
            group (torch.Tensor): The group of the current state; flattened
                to ``(N, group_dim)`` using its last dimension.
            action (torch.Tensor): The action index taken from each state
                (one scalar per row).

        Returns:
            torch.Tensor: The predicted reward, shape ``(N, 1)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        state = state.reshape(-1, state.size(-1))
        group = group.reshape(-1, group.size(-1))

        # Pass a flat index vector: an (N, 1) column would broadcast against
        # the row indices inside one_hot_encode and mark every action seen
        # in the batch on every row.
        action = one_hot_encode(action.reshape(-1), self.n_actions)

        x = torch.cat([state, group, action], dim=-1)
        return self.net(x)


class DynamicsModel(nn.Module):
    """
    Models the dynamics model as a discrete Gaussian distribution.

    A shared MLP trunk feeds two linear heads that predict the mean and the
    log-variance of a Gaussian over the next state; samples are rounded to
    the nearest integer with a straight-through gradient.
    """

    def __init__(
        self,
        state_without_group_dim: int,
        group_dim: int,
        n_actions: int,
        output_dim: int,
        hidden_width: int
    ) -> None:
        """
        Initializes the DynamicsModel.

        Args:
            state_without_group_dim (int): The state dimension without the group dimension.
            group_dim (int): The group dimension.
            n_actions (int): The number of actions.
            output_dim (int): The output dimension.
            hidden_width (int): The hidden width of the neural network.
        """
        super().__init__()

        self.n_actions = n_actions
        self.fc_base = nn.Sequential(
            layer_init(
                nn.Linear(
                    state_without_group_dim + group_dim + n_actions, hidden_width
                )
            ),
            nn.ReLU(),
            layer_init(nn.Linear(hidden_width, hidden_width)),
            nn.ReLU(),
        )
        self.fc_mean = layer_init(nn.Linear(hidden_width, output_dim), std=1.0)
        self.fc_logvar = layer_init(nn.Linear(hidden_width, output_dim), std=1.0)

    def forward(
        self, state: torch.Tensor, group: torch.Tensor, action: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass of the dynamics model.

        Args:
            state (torch.Tensor): The current state; flattened to
                ``(N, state_dim)`` using its last dimension.
            group (torch.Tensor): The group of the current state; flattened
                to ``(N, group_dim)`` using its last dimension.
            action (torch.Tensor): The action index taken from each state
                (one scalar per row).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The mean and the log-variance
                of the Gaussian distribution, each of shape
                ``(N, output_dim)``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs.
        state = state.reshape(-1, state.size(-1))
        group = group.reshape(-1, group.size(-1))

        # Pass a flat index vector: an (N, 1) column would broadcast against
        # the row indices inside one_hot_encode and mark every action seen
        # in the batch on every row.
        action = one_hot_encode(action.reshape(-1), self.n_actions)

        x = torch.cat([state, group, action], dim=-1)
        x = self.fc_base(x)
        mean = self.fc_mean(x)
        logvar = self.fc_logvar(x)
        return mean, logvar

    def sample_next_state(
        self, state: torch.Tensor, group: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """
        Samples the next state from the Gaussian distribution.

        Args:
            state (torch.Tensor): The current state.
            group (torch.Tensor): The group of the current state.
            action (torch.Tensor): The action taken from the current state.

        Returns:
            torch.Tensor: The sampled next state, rounded to the nearest
                integer value but still differentiable w.r.t. the mean and
                log-variance.
        """
        mean, logvar = self.forward(state, group, action)
        std = torch.exp(0.5 * logvar)
        # Reparameterized sample so gradients reach mean and std.
        gaussian_sample = mean + std * torch.randn_like(mean)
        discrete_sample = torch.round(gaussian_sample)
        # Straight-Through Estimator (STE): the forward value is the rounded
        # sample, the backward pass sees the identity through the rounding.
        continuous_sample = gaussian_sample + (discrete_sample - gaussian_sample).detach()
        return continuous_sample
    

class Bisimulator:
    """
    Implements the Bisimulator algorithm for MDP fairness.
    """
    
    def __init__(
        self,
        state_without_group_dim: int,
        actual_next_state_dim: int,
        group_dim: int,
        n_actions: int,        
        hidden_width: int,
        env_name : Literal["lending", "college"] = "lending",
        learning_rate: float = 2.5e-4,
        final_learning_rate: float = 1e-4,
        use_anneal_lr: bool = True,
        batch_size: int = 512,
        gamma: float = 0.99,
        dyn_model_epochs: int = 400,
        start_dyn_opt_step: int = 20000,
        dyn_opt_iters: int = 400,
        max_episode_steps: int = 500,
        dyn_rollout_steps: int = 1000,
        rew_coef: float = 1.0,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """
        Initializes the Bisimulator algorithm.
        
        Args:
            state_without_group_dim (int): The state dimension without the group dimension.
            actual_next_state_dim (int): The actual next state dimension.
            group_dim (int): The group dimension.
            n_actions (int): The number of actions.
            env_name (Literal["lending", "college"]): The environment name.
            hidden_width (int): The hidden width of the neural networks.
            learning_rate (float): The learning rate.
            batch_size (int): The batch size for training the dynamics model.
            dyn_model_epochs (int): The number of epochs to train the dynamics model.
            start_dyn_opt_step (int): The step to start optimizing the MDP dynamics.
            dyn_opt_iters (int): The number of iterations to optimize the MDP dynamics.
            max_episode_steps (int): The maximum number of steps in an episode.
            dyn_rollout_steps (int): The number of steps to collect the dynamics rollout.
            rew_coef (float): The coefficient for the reward loss.s
            device (torch.device): The device to run the algorithm.
        """        
        assert group_dim == 2, "Only 2 groups are supported"

        self.state_without_group_dim = state_without_group_dim
        self.actual_next_state_dim = actual_next_state_dim
        self.group_dim = group_dim
        self.env_name = env_name
        self.hidden_width = hidden_width
        self.learning_rate = learning_rate
        self.final_learning_rate = final_learning_rate
        self.use_anneal_lr = use_anneal_lr
        self.gamma = gamma
        self.n_actions = n_actions
        self.batch_size = batch_size
        self.dyn_model_epochs = dyn_model_epochs
        self.start_dyn_opt_step = start_dyn_opt_step
        self.dyn_opt_iters = dyn_opt_iters
        self.max_episode_steps = max_episode_steps
        self.dyn_rollout_steps = dyn_rollout_steps
        self.rew_coef = rew_coef
        self.device = device

        self.reward = Reward(
            state_without_group_dim=state_without_group_dim,
            group_dim=group_dim,
            n_actions=n_actions,
            hidden_width=hidden_width,
        ).to(device)
        self.dynamics_model = DynamicsModel(
            state_without_group_dim=state_without_group_dim,
            group_dim=group_dim,
            n_actions=n_actions,
            output_dim=1, # The output of the dynamics model is the next state id
            hidden_width=hidden_width,
        ).to(device)

        self.reward_optimizer = optim.Adam(
            self.reward.parameters(), lr=learning_rate
        )
        self.dynamics_model_optimizer = optim.Adam(
            self.dynamics_model.parameters(), lr=learning_rate
        )
        
        # Nevergrad optimizer for optimizing the MDP dynamicss
        if env_name == "lending":
            self.parametrization = ng.p.Instrumentation(
                positive_credit_changes=ng.p.Array(init=np.ones(group_dim,)).set_bounds(lower=1, upper=3).set_integer_casting(),
                negative_credit_changes=ng.p.Array(init=np.ones(group_dim,)).set_bounds(lower=1, upper=3).set_integer_casting(),
            )
        elif env_name == "college":
            self.parametrization = ng.p.Instrumentation(
                score_changes_coef=ng.p.Array(init=np.ones(group_dim,)).set_bounds(lower=1, upper=3).set_integer_casting(),
            )
        else:
            raise ValueError("Unsupported environment")
        
        self.ng_optimizer = ng.optimizers.DiscreteOnePlusOne(
            parametrization=self.parametrization, budget=dyn_opt_iters,
        )
        self.optimized_mdp_dyn = False
    
    def anneal_lr(self, current_step: int, total_steps: int) -> None:
        if self.use_anneal_lr:
            lr = self.learning_rate - (self.learning_rate - self.final_learning_rate) * (current_step / total_steps)
            self.reward_optimizer.param_groups[0]["lr"] = lr
            self.dynamics_model_optimizer.param_groups[0]["lr"] = lr

    def get_reward(
        self, state: torch.Tensor, group: torch.Tensor, action: torch.Tensor
    ) -> torch.Tensor:
        """
        Returns the reward using the reward function.
        
        Args:
            state (torch.Tensor): The current state.
            group (torch.Tensor): The group of the current state.
            action (torch.Tensor): The action taken from the current state.
        
        Returns:
            torch.Tensor: The reward.
        """
        state = state.to(self.device)
        group = group.to(self.device)
        action = action.to(self.device)
        return self.reward(state, group, action)

    def update_reward(self, bisimulator_buffer: BisimulatorBuffer) -> dict[str, float]:
        """
        Updates the reward function using the bisimulation 
        loss between group-conditioned MDPs.
        
        Args:
            bisimulator_buffer (BisimulatorBuffer): The buffer containing the data.
        
        Returns:
            dict[str, float]: The metrics for the reward model training.
        """
        states, actions, groups, rewards, next_states = bisimulator_buffer.get_data()

        loss, count, total_min_length = 0, 0, 0

        # Reshape and convert one-hot vectors to scalars
        states = states.reshape(-1, self.state_without_group_dim)
        groups = groups.reshape(-1, self.group_dim)
        states_id = torch.argmax(states, dim=-1, keepdim=True).float()

        # Get indices of each group
        group_0_indices = torch.where(groups[:, 0] == 1)[0]
        group_1_indices = torch.where(groups[:, 1] == 1)[0]
        if group_0_indices.size(0) == 0 or group_1_indices.size(0) == 0:
            return {"losses/bisim_reward_loss": 0.0, "losses/bisim_avg_min_length": 0.0}
        
        # Get the quartiles for each group
        group_0_quantile_masks = get_quantile_masks(states_id[group_0_indices])
        group_1_quantile_masks = get_quantile_masks(states_id[group_1_indices])

        for i in range(len(group_0_quantile_masks)):
            states_group_0 = states[group_0_indices][group_0_quantile_masks[i]]
            states_group_1 = states[group_1_indices][group_1_quantile_masks[i]]

            groups_group_0 = groups[group_0_indices][group_0_quantile_masks[i]]
            groups_group_1 = groups[group_1_indices][group_1_quantile_masks[i]]

            actions_group_0 = actions[group_0_indices][group_0_quantile_masks[i]]
            actions_group_1 = actions[group_1_indices][group_1_quantile_masks[i]]

            rewards_group_0 = rewards[group_0_indices][group_0_quantile_masks[i]]
            rewards_group_1 = rewards[group_1_indices][group_1_quantile_masks[i]]

            for a in range(self.n_actions):
                group_0_selected = torch.where(actions_group_0 == a)[0]
                group_1_selected = torch.where(actions_group_1 == a)[0]
                min_length = min(len(group_0_selected), len(group_1_selected))
                if min_length == 0:
                    continue

                loss += F.smooth_l1_loss(
                    rewards_group_0[group_0_selected][:min_length]
                    + self.rew_coef * self.reward(
                        states_group_0[group_0_selected][:min_length],
                        groups_group_0[group_0_selected][:min_length],
                        actions_group_0[group_0_selected][:min_length],
                    ),
                    rewards_group_1[group_1_selected][:min_length]
                    + self.rew_coef * self.reward(
                        states_group_1[group_1_selected][:min_length],
                        groups_group_1[group_1_selected][:min_length],
                        actions_group_1[group_1_selected][:min_length],
                    ),
                )
                count += 1
                total_min_length += min_length

        if count > 0:
            loss /= count
            self.reward_optimizer.zero_grad()
            loss.backward()
            self.reward_optimizer.step()

            metrics = {
                "losses/bisim_reward_loss": loss.item(),
                "losses/bisim_avg_min_length": total_min_length / count,
            }
        else:
            metrics = {
                "losses/bisim_reward_loss": 0.0,
                "losses/bisim_avg_min_length": 0.0,
            }
        return metrics
    
    def train_dynamics_model(self, bisimulator_buffer: BisimulatorBuffer) -> dict[str, float]:
        """
        Trains the dynamics model using the MSE loss between the predicted next state
        and the actual next state. 
        
        Args:
            bisimulator_buffer (BisimulatorBuffer): The buffer containing the data.
        
        Returns:
            dict[str, float]: The metrics for the dynamics model training.
        """
        states, actions, groups, _, next_states = bisimulator_buffer.get_data()
        states = states.reshape(-1, self.state_without_group_dim)
        groups = groups.reshape(-1, self.group_dim)
        
        next_states_id = torch.argmax(next_states, dim=-1, keepdim=True).reshape(-1, 1).float()

        loss, total_loss, correct, total = 0, 0, 0, 0
        
        for i in range(self.dyn_model_epochs):
            for j in range(0, states.size(0), self.batch_size):
                states_batch = states[j:j+self.batch_size]
                actions_batch = actions[j:j+self.batch_size]
                groups_batch = groups[j:j+self.batch_size]
                next_states_batch = next_states_id[j:j+self.batch_size]
                
                predicted_next_states = self.dynamics_model.sample_next_state(states_batch, groups_batch, actions_batch)
                loss = F.mse_loss(predicted_next_states, next_states_batch)
                
                self.dynamics_model_optimizer.zero_grad()
                loss.backward()
                self.dynamics_model_optimizer.step()
                
                total_loss += loss.item()
                correct += (predicted_next_states == next_states_batch).sum().item()
                total += next_states_batch.size(0)
            
        metrics = {
            "losses/dynamics_model_loss": total_loss / self.dyn_model_epochs,
            "losses/dynamics_model_accuracy": correct / total,
        }
        return metrics
    
    def collect_rollout(self, env: Union[LendingEnv, CollegeAdmissionEnv], agent: PPO)-> BisimulatorBuffer:
        """
        Collects a rollout of the environment using the current policy.
        
        Args:
            env (Union[LendingEnv, CollegeAdmissionEnv]): The environment to collect the rollout.
            agent (PPO): The agent to collect the rollout.
        
        Returns:
            BisimulatorBuffer: The buffer containing the rollout data.
        """
        if self.env_name == "lending":
            prepocessor_fn = preprocess_lending_obs
        elif self.env_name == "college":
            prepocessor_fn = preprocess_college_admission_obs
        else:
            raise ValueError("Unsupported environment")
        
        buffer = BisimulatorBuffer(
            self.state_without_group_dim, 
            self.actual_next_state_dim,
            self.group_dim, 
            self.dyn_rollout_steps,
            self.device
        )
        episode_step = 0
        
        obs, _ = env.reset()
        next_done = torch.zeros(1)
        state, state_without_group, _, group = prepocessor_fn(obs)
        
        for step in range(0, self.dyn_rollout_steps):
            episode_step += 1
            
            # Option 1: Give loan to everyone
            action = torch.tensor([AcceptRejectAction.ACCEPT.value]).view(-1)
            
            # Option 2: Use the trained RL agent to select actions
            # with torch.no_grad():
            #     action, _, _, _ = agent.get_action_and_value(state)
            
            next_obs, env_reward, terminated, truncated, _ = env.step(action.cpu().numpy())
            done = terminated or truncated
            
            next_state, next_state_without_group, actual_next_state, next_group = prepocessor_fn(next_obs)
            env_reward = torch.tensor([env_reward]).view(-1)
            
            buffer.add(state_without_group, action, group, env_reward, actual_next_state)
            
            group = torch.Tensor(next_group)
            state = torch.Tensor(next_state)
            state_without_group = torch.Tensor(next_state_without_group)
            next_done = torch.Tensor([done])
            
            if next_done or episode_step == self.max_episode_steps:
                obs, _ = env.reset()
                state, state_without_group, _, group = prepocessor_fn(obs)
                episode_step = 0
        
        return buffer
    
    def update_dynamics(self, env: Union[LendingEnv, CollegeAdmissionEnv], agent: PPO, global_step: int):
        """
        Updates the dynamics parameters of the MDP using the 
        bisimulation loss between group-conditioned MDPs.
        
        Args:
            env (Union[LendingEnv, CollegeAdmissionEnv]): The environment.
            agent (PPO): The agent.
            global_step (int): The global step.
        
        Returns:
            
        """
        metrics = {}
        
        def compute_bisim_dyn_loss(buffer):
            # Evaluate the bisimulation metric between subgroups
            states, actions, groups, _, _ = buffer.get_data()
            loss = 0

            # Reshape and convert one-hot vectors to scalars
            states = states.reshape(-1, self.state_without_group_dim)
            groups = groups.reshape(-1, self.group_dim)
            states_id = torch.argmax(states, dim=-1, keepdim=True).float()

            # Get indices of each group
            group_0_indices = torch.where(groups[:, 0] == 1)[0]
            group_1_indices = torch.where(groups[:, 1] == 1)[0]

            # Get the quartiles for each group
            group_0_quantile_masks = get_quantile_masks(states_id[group_0_indices])
            group_1_quantile_masks = get_quantile_masks(states_id[group_1_indices])
            if group_0_indices.size(0) == 0 or group_1_indices.size(0) == 0:
                return 0.0

            for i in range(len(group_0_quantile_masks)):
                states_group_0 = states[group_0_indices][group_0_quantile_masks[i]]
                states_group_1 = states[group_1_indices][group_1_quantile_masks[i]]

                groups_group_0 = groups[group_0_indices][group_0_quantile_masks[i]]
                groups_group_1 = groups[group_1_indices][group_1_quantile_masks[i]]

                actions_group_0 = actions[group_0_indices][group_0_quantile_masks[i]]
                actions_group_1 = actions[group_1_indices][group_1_quantile_masks[i]]
                
                for a in range(self.n_actions):
                    group_0_selected = torch.where(actions_group_0 == a)[0]
                    group_1_selected = torch.where(actions_group_1 == a)[0]
                    min_length = min(len(group_0_selected), len(group_1_selected))
                    if min_length == 0:
                        continue
                    
                    pred_group_0_mean, pred_group_0_logvar = self.dynamics_model.forward(
                        states_group_0[group_0_selected][:min_length],
                        groups_group_0[group_0_selected][:min_length],
                        actions_group_0[group_0_selected][:min_length],
                    )
                    
                    pred_group_1_mean, pred_group_1_logvar = self.dynamics_model.forward(
                        states_group_1[group_1_selected][:min_length],
                        groups_group_1[group_1_selected][:min_length],
                        actions_group_1[group_1_selected][:min_length],
                    )
                    
                    pred_group_0_std = torch.exp(0.5 * pred_group_0_logvar)
                    pred_group_1_std = torch.exp(0.5 * pred_group_1_logvar)
                    
                    bisimulation_dist = torch.sqrt(
                        (pred_group_0_mean - pred_group_1_mean).pow(2) + 
                        (pred_group_0_std - pred_group_1_std).pow(2)
                    ).mean()
                    
                    loss += bisimulation_dist.item()
                    
            return self.gamma * loss 
        
        def lending_cost_function(positive_credit_changes, negative_credit_changes):
            # Set the environment parameters
            env.unwrapped.set_credit_changes(
                positive_credit_changes, 
                negative_credit_changes,
            )
            
            # Collect rollout
            buffer = self.collect_rollout(env, agent)
            # Train dynamics model
            metrics.update(self.train_dynamics_model(buffer))
            # Compute the bisimulation loss
            loss = compute_bisim_dyn_loss(buffer)
            
            return loss

        def college_cost_function(score_changes_coef):
            # Set the environment parameters
            env.unwrapped.set_score_changes(score_changes_coef)
            
            # Collect rollout
            buffer = self.collect_rollout(env, agent)
            # Train dynamics model
            metrics.update(self.train_dynamics_model(buffer))
            # Compute the bisimulation loss
            loss = compute_bisim_dyn_loss(buffer)
            
            return loss
        
        if global_step > self.start_dyn_opt_step and not self.optimized_mdp_dyn:
            # Optimize the MDP dynamics
            print("Optimizing the MDP dynamics...")
            
            if self.env_name == "lending":
                for i in range(self.dyn_opt_iters):
                    recommendation = self.ng_optimizer.ask()
                    opt_loss = lending_cost_function(**recommendation.kwargs)
                    self.ng_optimizer.tell(recommendation, opt_loss)
                    
                    print(f"Iter: {i}, loss: {opt_loss}")
                
                recommendation = self.ng_optimizer.provide_recommendation()
                self.optimized_mdp_dyn = True
                
                metrics['losses/bisim_dynamics_loss'] = lending_cost_function(**recommendation.kwargs)
                print(f"Optimized MDP dynamics parameters: {recommendation.kwargs}")
                print(f"Final cost function: {metrics['losses/bisim_dynamics_loss']}")
                
                self.positive_credit_changes = recommendation.kwargs["positive_credit_changes"]
                self.negative_credit_changes = recommendation.kwargs["negative_credit_changes"]
            
                return (
                    metrics, 
                    self.positive_credit_changes,
                    self.negative_credit_changes,
                )
            elif self.env_name == "college":
                for i in range(self.dyn_opt_iters):
                    recommendation = self.ng_optimizer.ask()
                    opt_loss = college_cost_function(**recommendation.kwargs)
                    self.ng_optimizer.tell(recommendation, opt_loss)
                    
                    print(f"Iter: {i}, loss: {opt_loss}")
                
                recommendation = self.ng_optimizer.provide_recommendation()
                self.optimized_mdp_dyn = True
                
                metrics['losses/bisim_dynamics_loss'] = college_cost_function(**recommendation.kwargs)
                print(f"Optimized MDP dynamics parameters: {recommendation.kwargs}")
                print(f"Final cost function: {metrics['losses/bisim_dynamics_loss']}")
                
                self.score_changes_coef = recommendation.kwargs["score_changes_coef"]
                
                return (
                    metrics, 
                    self.score_changes_coef,
                )
            else:
                raise ValueError("Unsupported environment")
        
        elif global_step > self.start_dyn_opt_step and self.optimized_mdp_dyn:
            # Only return the MDP dynamics parameters
            if self.env_name == "lending":
                return (
                    metrics,
                    self.positive_credit_changes,
                    self.negative_credit_changes,
                )
            elif self.env_name == "college":
                return (
                    metrics,
                    self.score_changes_coef,
                )
            else:
                raise ValueError("Unsupported environment")
        
        else:
            # Only train the dynamics model 
            buffer = self.collect_rollout(env, agent)
            metrics = self.train_dynamics_model(buffer)
            
            if self.env_name == "lending":
                return metrics, None, None
            elif self.env_name == "college":
                return metrics, None
            else:
                raise ValueError("Unsupported environment")

    def save(self, save_path: str) -> None:
        """
        Saves the Bisimulator model to the specified directory.
        
        Args:
            save_path (str): The directory to save the model.
        """
        os.makedirs(save_path, exist_ok=True)
        torch.save(self.reward.state_dict(), f"{save_path}/reward.pt")
        torch.save(self.dynamics_model.state_dict(), f"{save_path}/dynamics_model.pt")
    
    def load(self, load_path: str) -> None:
        """
        Loads the Bisimulator model from the specified directory.
        
        Args:
            load_path (str): The directory to load the model.
        """
        self.reward.load_state_dict(torch.load(f"{load_path}/reward.pt"))
        self.dynamics_model.load_state_dict(torch.load(f"{load_path}/dynamics_model.pt"))
