import os
import time
from typing import Dict, List, Optional, Iterable, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gymnasium import Wrapper

from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.policy.sample_batch import SampleBatch
from reward_modeling.replay_buffer import ReplayBuffer


class RewardModel(nn.Module):
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        sequence_lens: int,
        discrete_actions: bool,
        env_name: str,
        unique_id: str,
        rm_id: str,
        lr: float,
        n_epochs: int,
        n_layers: int = 5,
        layer_size: int = 512,
        use_weight_decay: bool = True,
    ):
        super().__init__()
        self.sequence_lens = sequence_lens
        self.action_dim = action_dim
        self.discrete_actions = discrete_actions
        self.env_name = env_name
        self.unique_id = unique_id
        self.rm_id = rm_id

        layers: List[nn.Module] = []
        for i in range(n_layers):
            in_features = obs_dim + action_dim if i == 0 else layer_size
            out_features = 1 if i == n_layers - 1 else layer_size
            layers.append(nn.Linear(in_features, out_features))
            if i < n_layers - 1:
                layers.append(nn.ReLU(inplace=True))
        self.model = nn.Sequential(*layers)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)

        self.use_weight_decay = use_weight_decay
        self.reinitialize_parameters(lr=lr, n_epochs=n_epochs)

    def perturb_weights(self, std: float = 0.05) -> None:
        with torch.no_grad():
            for param in self.parameters():
                param.add_(torch.randn_like(param) * std)

    def reinitialize_parameters(
        self, lr: Optional[float] = None, n_epochs: Optional[int] = None, perturb_std: Optional[float] = None
    ) -> None:
        """
        Fully reinitialize network parameters and (re)create optimizer.
        """
        if lr is None:
            lr = getattr(self, "_init_lr", None)
        if n_epochs is None:
            n_epochs = getattr(self, "_init_epochs", None)
        if lr is None or n_epochs is None:
            raise ValueError("lr and n_epochs must be provided at least once before calling without arguments.")

        self._init_lr = lr
        self._init_epochs = n_epochs

        with torch.no_grad():
            for m in self.modules():
                if isinstance(m, nn.Linear):
                    m.reset_parameters()

        if perturb_std is not None and perturb_std > 0:
            with torch.no_grad():
                for p in self.parameters():
                    p.add_(torch.randn_like(p) * perturb_std)

        wd = 1e-5 if self.use_weight_decay else 0.0
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)
        self.optimizer.zero_grad(set_to_none=True)


    def zero_model_params(
        self, zero_entire_model: bool = False, freeze_output: bool = False, clear_opt_state: bool = True
    ) -> None:
        """
        Make the model's output identically zero for any input by zeroing the final Linear layer
        (or all Linear layers if zero_entire_model=True).
        """
        linear_layers = [m for m in self.modules() if isinstance(m, nn.Linear)]
        if not linear_layers:
            raise RuntimeError("RewardModel has no Linear layers to zero.")

        with torch.no_grad():
            if zero_entire_model:
                for layer in linear_layers:
                    layer.weight.zero_()
                    if layer.bias is not None:
                        layer.bias.zero_()
                last_linear = linear_layers[-1]
            else:
                last_linear = linear_layers[-1]
                last_linear.weight.zero_()
                if last_linear.bias is not None:
                    last_linear.bias.zero_()

        # Keeping freeze/clear disabled to preserve current behavior

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def save_params(self) -> None:
        """Write the model's state dict to its checkpoint file, creating the directory if needed."""
        checkpoint_dir = f"active_models/{self.rm_id}"
        checkpoint_path = f"active_models/{self.rm_id}/reward_model_{self.unique_id}.pth"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(self.state_dict(), checkpoint_path)


    def load_params(self, map_to_cpu: bool = False) -> None:
        """Load the state dict from the checkpoint file and switch the module to train mode.

        Args:
            map_to_cpu: When True, stored tensors are remapped to the CPU while loading.

        Raises:
            ValueError: If ``unique_id`` is None.
        """
        if self.unique_id is None:
            raise ValueError("unique_id must be set to load parameters")
        print("Loading reward model parameters from unique id:", self.unique_id)

        location = torch.device("cpu") if map_to_cpu else None
        checkpoint = torch.load(self.get_fp(), map_location=location)
        self.load_state_dict(checkpoint)
        # Callers that need inference mode must call eval() themselves afterwards.
        self.train()

    def get_fp(self) -> str:
        """Return the checkpoint file path derived from ``rm_id`` and ``unique_id``.

        Raises:
            ValueError: If ``unique_id`` is None.
        """
        identifier = self.unique_id
        if identifier is None:
            raise ValueError("unique_id must be set to load parameters")
        return f"active_models/{self.rm_id}/reward_model_{self.unique_id}.pth"


class RewardModelTrainer:
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        sequence_lens: int,
        discrete_actions: bool,
        env_name: str,
        lr: float = 0.001,
        n_epochs: int = 250,
        n_layers: int = 5,
        layer_size: int = 512,
        use_weight_decay: bool = True,
        unique_id: Optional[str] = None,
        n_reward_models: int = 10,
    ):
        """Create an ensemble of reward models plus the replay buffers used to train them.

        Args:
            obs_dim: Observation feature count forwarded to every ``RewardModel``.
            action_dim: Action feature count forwarded to every ``RewardModel``.
            sequence_lens: Forwarded to every ``RewardModel`` and stored on the trainer.
            discrete_actions: Whether actions are one-hot encoded before being
                concatenated with observations.
            env_name: Environment name; several methods branch on substrings of it.
            lr: Learning rate for every ensemble member's optimizer.
            n_epochs: Number of training epochs run by ``update_params``.
            n_layers: Number of Linear layers per reward model.
            layer_size: Hidden width per reward model.
            use_weight_decay: Whether the optimizers apply weight decay.
            unique_id: Identifier used in checkpoint file names. Required.
            n_reward_models: Ensemble size.

        Raises:
            ValueError: If ``unique_id`` is None.
        """
        # Fail fast: without an id nothing can be saved, so there is no point in
        # first allocating the ensemble and the replay buffers.
        if unique_id is None:
            raise ValueError("unique_id must be set to save parameters")

        self.sequence_lens = sequence_lens
        self.action_dim = action_dim
        self.discrete_actions = discrete_actions
        self.env_name = env_name
        self.unique_id = unique_id
        self.n_reward_models = n_reward_models
        self.lr = lr
        self.n_epochs = n_epochs

        print("Create rm with unique_id:", self.unique_id)
        # Each member is tagged with its ensemble index as rm_id.
        self.reward_ensemble: List[RewardModel] = [
            RewardModel(
                obs_dim=obs_dim,
                action_dim=action_dim,
                sequence_lens=sequence_lens,
                discrete_actions=discrete_actions,
                env_name=env_name,
                unique_id=unique_id,
                rm_id=str(i),
                lr=lr,
                n_epochs=n_epochs,
                n_layers=n_layers,
                layer_size=layer_size,
                use_weight_decay=use_weight_decay,
            )
            for i in range(self.n_reward_models)
        ]

        self.sigmoid = nn.Sigmoid()

        # Replay buffers: `candidate_replay_buffer` receives every labelled pair,
        # `replay_buffer` holds the pairs actually trained on.
        self.replay_buffer = ReplayBuffer(1_000_000)
        self.candidate_replay_buffer = ReplayBuffer(1_000_000)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Cache of trajectory dicts, keyed by the string ids stored in the replay buffers.
        self.indice2traj: Dict[str, Dict[str, torch.Tensor]] = {}
        self.num_batches_added = 0

        # NOTE(review): the reason for this one-second pause is not visible in
        # this file; it is kept to preserve existing timing behavior.
        time.sleep(1)

    def zero_model_params(
        self,
        zero_entire_model: bool = False,
        freeze_output: bool = False,
        clear_opt_state: bool = False,
        model_indices: Optional[Iterable[int]] = None,
    ) -> None:
        """Zero the parameters of the selected ensemble members.

        Args:
            zero_entire_model: Forwarded to each model's ``zero_model_params``.
            freeze_output: Forwarded to each model's ``zero_model_params``.
            clear_opt_state: Forwarded to each model's ``zero_model_params``.
            model_indices: Ensemble indices to zero; None selects every member.
        """
        if model_indices is None:
            selected = self.reward_ensemble
        else:
            selected = [self.reward_ensemble[idx] for idx in model_indices]

        forwarded = dict(
            zero_entire_model=zero_entire_model,
            freeze_output=freeze_output,
            clear_opt_state=clear_opt_state,
        )
        for model in selected:
            model.zero_model_params(**forwarded)

    def save_random_params(self) -> None:
        """Save the state dict of one randomly chosen ensemble member under ``active_models/random``.

        Raises:
            ValueError: If ``unique_id`` is None.
        """
        if self.unique_id is None:
            raise ValueError("unique_id must be set to save parameters")

        chosen = np.random.randint(0, self.n_reward_models)
        os.makedirs("active_models/random", exist_ok=True)
        weights = self.reward_ensemble[chosen].state_dict()
        torch.save(weights, f"active_models/random/reward_model_{self.unique_id}.pth")

    @staticmethod
    def _create_sample_batch(rewards, actions, obs, new_obs, reward_for_pref, proxy_rewards, index: int) -> Dict[str, torch.Tensor]:
        """Pick sequence ``index`` from each per-field collection and wrap each one in a tensor.

        Returns:
            A dict with one tensor per field, keyed by the ``SampleBatch``
            reward/action/obs constants plus ``"new_obs"``,
            ``"reward_for_pref"`` and ``"proxy_rewards"``.
        """
        sources = (
            (SampleBatch.REWARDS, rewards),
            (SampleBatch.ACTIONS, actions),
            (SampleBatch.OBS, obs),
            ("new_obs", new_obs),
            ("reward_for_pref", reward_for_pref),
            ("proxy_rewards", proxy_rewards),
        )
        return {field: torch.tensor(collection[index]) for field, collection in sources}

    def _get_concatenated_obs_action(
        self, obs: torch.Tensor, new_obs: Optional[torch.Tensor], actions: torch.Tensor
    ) -> torch.Tensor:
        if self.discrete_actions:
            encoded_actions = F.one_hot(actions.long(), self.action_dim)
            net_input = torch.cat([obs, new_obs, encoded_actions], dim=1) if new_obs is not None else torch.cat(
                [obs, encoded_actions], dim=1
            )
        else:
            net_input = torch.cat([obs, new_obs, actions], dim=1) if new_obs is not None else torch.cat([obs, actions], dim=1)
        return net_input.to(torch.float32)

    @staticmethod
    def _calculate_discounted_sum_and_diffs(traj1_rews: torch.Tensor, traj2_rews: torch.Tensor, gamma: float = 0.99) -> torch.Tensor:
        discounts1 = gamma ** torch.arange(len(traj1_rews), device=traj1_rews.device)
        discounts2 = gamma ** torch.arange(len(traj2_rews), device=traj2_rews.device)
        return (discounts2 * traj2_rews).sum(dim=0) - (discounts1 * traj1_rews).sum(dim=0)

    def _calculate_true_reward_comparisons(self, traj1: Dict[str, torch.Tensor], traj2: Dict[str, torch.Tensor]) -> torch.Tensor:
        traj1_true_rewards = traj1["reward_for_pref"]
        traj2_true_rewards = traj2["reward_for_pref"]
        rewards_diff = self._calculate_discounted_sum_and_diffs(traj1_true_rewards, traj2_true_rewards)
        probs = torch.sigmoid(rewards_diff)
        return (torch.rand(probs.size(), device=probs.device) < probs).float()

    def _calculate_boltzmann_pred_probs(
        self, traj1: Dict[str, torch.Tensor], traj2: Dict[str, torch.Tensor], rm_i: int, modify_proxy_reward: bool = False
    ) -> Tuple[torch.Tensor, np.ndarray]:
        if "pandemic" in self.env_name and "sas" not in self.env_name:
            net_input1 = self._get_concatenated_obs_action(traj1["obs"].flatten(1).to(self.device), None, traj1["actions"].to(self.device))
            net_input2 = self._get_concatenated_obs_action(traj2["obs"].flatten(1).to(self.device), None, traj2["actions"].to(self.device))
        else:
            net_input1 = self._get_concatenated_obs_action(
                traj1["obs"].flatten(1).to(self.device), traj1["new_obs"].flatten(1).to(self.device), traj1["actions"].to(self.device)
            )
            net_input2 = self._get_concatenated_obs_action(
                traj2["obs"].flatten(1).to(self.device), traj2["new_obs"].flatten(1).to(self.device), traj2["actions"].to(self.device)
            )

        traj1_preds = self.reward_ensemble[rm_i].forward(net_input1).flatten()
        traj2_preds = self.reward_ensemble[rm_i].forward(net_input2).flatten()

        if modify_proxy_reward:
            # Appendix F.2 of https://arxiv.org/pdf/2507.00611 uses tanh on residuals to keep values small
            traj1_preds = torch.tanh(traj1_preds) + traj1["proxy_rewards"].flatten().to(self.device)
            traj2_preds = torch.tanh(traj2_preds) + traj2["proxy_rewards"].flatten().to(self.device)

        preds_diff = self._calculate_discounted_sum_and_diffs(traj1_preds, traj2_preds)
        softmax_probs = torch.sigmoid(preds_diff)
        return softmax_probs.float(), traj1_preds.detach().cpu().numpy()

    @staticmethod
    def _truncate_trajectory(traj: Dict[str, torch.Tensor], max_len: Optional[int]) -> Dict[str, torch.Tensor]:
        if max_len is None:
            return traj
        for key in traj.keys():
            traj[key] = traj[key][:max_len]
        return traj

    @staticmethod
    def split_by_sequence(arr, batch_seq_lens: Iterable[int]) -> List:
        """Slice a flat sequence into variable-length chunks."""
        sequences = []
        start = 0
        for length in batch_seq_lens:
            end = start + length
            sequences.append(arr[start:end])
            start = end
        if start != len(arr):
            raise ValueError(f"sum(batch_seq_lens)={start} must equal len(arr)={len(arr)}")
        return sequences

    @staticmethod
    def third_and_triple_sequences(sequences: List) -> List:
        """Split sequences into three chunks when long; otherwise keep as-is."""
        out: List = []
        longest = max(len(seq) for seq in sequences)
        for seq in sequences:
            if len(seq) >= longest / 3:
                n = len(seq)
                out.append(seq[: n // 3])
                out.append(seq[n // 3 : (2 * n) // 3])
                out.append(seq[(2 * n) // 3 :])
            else:
                out.append(seq)
        return out

    def get_batch_sequences(self, train_batch, batch_seq_lens: Iterable[int]):
        """Split a flat training batch into per-episode sequences, longest first.

        Args:
            train_batch: Mapping that provides the ``SampleBatch`` reward,
                observation and action fields plus ``"new_obs"`` and ``"infos"``.
            batch_seq_lens: Length of each consecutive episode in the batch.
                Any iterable is accepted; it is materialized once.

        Returns:
            A 7-tuple ``(rewards, actions, obs, new_obs, rewards_for_prefs,
            proxy_rewards, batch_seq_lens)``. The first six are lists of
            sequences ordered by decreasing episode length; for environments
            whose name contains "glucose" they are additionally passed through
            ``third_and_triple_sequences``. ``batch_seq_lens`` is the sorted
            list of original episode lengths and is not adjusted for that
            extra splitting.

        Raises:
            ValueError: If the lengths do not add up to the batch size
                (raised by ``split_by_sequence``).
        """
        print("batch_seq_lens:", batch_seq_lens)
        # Materialize once: the lengths are iterated several times and indexed
        # below, which would exhaust a one-shot iterable.
        seq_lens = list(batch_seq_lens)

        actions = torch.from_numpy(convert_to_numpy(train_batch[SampleBatch.ACTIONS]))
        print("actions shape:", actions.shape)

        rewards_sequences = self.split_by_sequence(train_batch[SampleBatch.REWARDS], seq_lens)
        obs_sequences = self.split_by_sequence(train_batch[SampleBatch.OBS], seq_lens)
        new_obs_sequences = self.split_by_sequence(train_batch["new_obs"], seq_lens)
        acs_sequences = self.split_by_sequence(actions, seq_lens)

        print("after splitting by sequence:", len(acs_sequences))

        infos = train_batch["infos"]
        true_rews_flat = [info.get("true_rew", 0) for info in infos]
        # An all-zero total is treated as "key absent": fall back to the
        # alternative key "true_reward".
        if sum(true_rews_flat) == 0:
            true_rews_flat = [info.get("true_reward", 0) for info in infos]
        reward_sequences_for_prefs = self.split_by_sequence(true_rews_flat, seq_lens)

        proxy_reward_seq = self.split_by_sequence([info.get("original_reward", 0) for info in infos], seq_lens)

        # Stable sort of episode indices by decreasing length.
        order = sorted(range(len(seq_lens)), key=seq_lens.__getitem__, reverse=True)

        def reorder(items: List) -> List:
            return [items[i] for i in order]

        seq_lens = reorder(seq_lens)
        fields = [
            reorder(rewards_sequences),
            reorder(acs_sequences),
            reorder(obs_sequences),
            reorder(new_obs_sequences),
            reorder(reward_sequences_for_prefs),
            reorder(proxy_reward_seq),
        ]

        print("batch_seq_lens:", seq_lens)

        if "glucose" in self.env_name:
            fields = [self.third_and_triple_sequences(field) for field in fields]

        print("after halving and doubling sequences:", len(fields[1]))
        print("========")
        return (*fields, seq_lens)

    @staticmethod
    def get_seq_lens(train_batch1) -> np.ndarray:
        """Count how many rows belong to each episode id, ordered by first appearance.

        Args:
            train_batch1: Mapping providing the ``SampleBatch.EPS_ID`` column.

        Returns:
            Array with one count per distinct episode id, ordered by the
            position at which each id first occurs in the column.
        """
        eps_ids = train_batch1[SampleBatch.EPS_ID]
        print("len(eps_ids):", len(eps_ids))

        unique_vals, first_seen, occurrences = np.unique(eps_ids, return_index=True, return_counts=True)
        print("len(unique_vals):", len(unique_vals))

        # np.unique sorts by value; reorder the counts by first occurrence instead.
        batch_seq_lens_1 = occurrences[np.argsort(first_seen)]
        print("len(batch_seq_lens_1):", len(batch_seq_lens_1))
        print("\n")
        return batch_seq_lens_1

    def add_candidates2replay(
        self,
        train_batch1,
        train_batch2,
        set_num_seqs: bool = False,
        skip_ref_trajs: bool = False,
        only_ref_trajs: bool = False,
    ) -> None:
        """Label trajectory pairs from two batches and push them to the candidate replay buffer.

        For every index pair ``(i, j)`` two kinds of comparison may be pushed:
        sequence ``i`` of batch 1 against sequence ``j`` of batch 2 (unless
        ``skip_ref_trajs``), and sequence ``i`` against sequence ``j`` within
        batch 1 (only when ``i != j`` and neither ``set_num_seqs`` nor
        ``only_ref_trajs`` is set). Trajectory dicts are cached in
        ``self.indice2traj`` under string keys and the buffer stores those keys.

        Side effects: sets ``self.n_prefs_per_epoch`` and increments
        ``self.num_batches_added``.

        Raises:
            AssertionError: If either batch has fewer sequences than the
                environment-specific minimum (20, 40 or 80).
        """
        seq_lens1 = self.get_seq_lens(train_batch1)
        seq_lens2 = self.get_seq_lens(train_batch2)

        # get_batch_sequences returns six per-field sequence lists, in the
        # argument order of _create_sample_batch, followed by the sorted lengths.
        *fields1, batch_seq_lens1 = self.get_batch_sequences(train_batch1, seq_lens1)
        *fields2, batch_seq_lens2 = self.get_batch_sequences(train_batch2, seq_lens2)
        rewards_sequences1 = fields1[0]
        rewards_sequences2 = fields2[0]

        num_sequences = min(len(rewards_sequences1), len(rewards_sequences2))
        if "tomato" in self.env_name or "traffic" in self.env_name:
            min_sequences = 20
        elif "glucose" in self.env_name:
            min_sequences = 40
        else:
            min_sequences = 80
        assert num_sequences >= min_sequences
        self.n_prefs_per_epoch = min_sequences * min_sequences

        if set_num_seqs:
            num_sequences = int(np.sqrt(self.n_prefs_per_epoch))
        else:
            num_sequences = min(num_sequences, 200)

        self.num_batches_added += 1
        # NOTE(review): both ranges stop at num_sequences - 1, so the last
        # sequence is never paired; confirm whether that is intentional.
        trajectory_pairs = [(i, j) for i in range(num_sequences - 1) for j in range(num_sequences - 1)]
        print("# of trajectory pairs 2 add:", len(trajectory_pairs))
        print("num_sequences:", num_sequences)
        print("(len(batch_seq_lens1), len(batch_seq_lens2))", (len(batch_seq_lens1), len(batch_seq_lens2)))
        print("len(rewards_sequences1):", len(rewards_sequences1))
        print("len(rewards_sequences2):", len(rewards_sequences2))

        def cached_trajectory(key: str, fields: List, seq_idx: int, limit: Optional[int]) -> Dict[str, torch.Tensor]:
            """Return the cached trajectory for ``key``, building and truncating it on first use."""
            if key not in self.indice2traj:
                fresh = self._create_sample_batch(*fields, seq_idx)
                self.indice2traj[key] = self._truncate_trajectory(fresh, limit)
            return self.indice2traj[key]

        is_glucose = "glucose" in self.env_name
        max_len = None
        over_opt_max_len = None
        for i_idx, j_idx in trajectory_pairs:
            # Only glucose trajectories are truncated to a common length.
            if is_glucose:
                max_len = min(len(rewards_sequences1[i_idx]), len(rewards_sequences2[j_idx]))

            traj1_key = f"{self.num_batches_added}_{i_idx}_t0_max_len={max_len}"
            traj1 = cached_trajectory(traj1_key, fields1, i_idx, max_len)

            traj2_key = f"{self.num_batches_added}_{j_idx}_t1_max_len={max_len}"
            if not skip_ref_trajs:
                traj2 = cached_trajectory(traj2_key, fields2, j_idx, max_len)
                true_reward_label = self._calculate_true_reward_comparisons(traj1, traj2).to(self.device)
                self.candidate_replay_buffer.push(traj1_key, traj2_key, true_reward_label)

            if i_idx == j_idx or set_num_seqs or only_ref_trajs:
                continue

            # Second comparison: two different sequences taken from batch 1.
            if is_glucose:
                over_opt_max_len = min(len(rewards_sequences1[i_idx]), len(rewards_sequences1[j_idx]))

            traj1_over_opt_key = f"{self.num_batches_added}_{i_idx}_t0_max_len={over_opt_max_len}"
            traj1_over_opt = cached_trajectory(traj1_over_opt_key, fields1, i_idx, over_opt_max_len)

            traj2_over_opt_key = f"{self.num_batches_added}_{j_idx}_t0_max_len={over_opt_max_len}"
            traj2_over_opt = cached_trajectory(traj2_over_opt_key, fields1, j_idx, over_opt_max_len)

            true_reward_label = self._calculate_true_reward_comparisons(traj1_over_opt, traj2_over_opt).to(self.device)
            self.candidate_replay_buffer.push(traj1_over_opt_key, traj2_over_opt_key, true_reward_label)

    def check_stopping_condition(self, over_opt_batch, ref_batch):
        """Return 0 unconditionally; both arguments are ignored (preserved behavior)."""
        stop_signal = 0
        return stop_signal

    def re_initialize_ensemble(self) -> None:
        """Reset every ensemble member's parameters and optimizer using the trainer's lr and n_epochs."""
        settings = {"lr": self.lr, "n_epochs": self.n_epochs}
        for member_idx in range(self.n_reward_models):
            member = self.reward_ensemble[member_idx]
            member.reinitialize_parameters(**settings)

    def update_params(
        self,
        train_batch1,
        train_batch2,
        iteration: int,
        debug_mode: bool = False,
        use_minibatch: bool = False,
        force_n_epochs: Optional[int] = None,
        use_all_pairs: bool = False,
        skip_ref_trajs: bool = False,
        only_ref_trajs: bool = False,
        modify_proxy_reward: bool = False,
    ) -> None:
        """Select training pairs, retrain the whole ensemble from scratch and save the results.

        Steps:
            1. Unless ``debug_mode``, label new candidate pairs from the two batches.
            2. With ``use_all_pairs`` the candidate buffer becomes the replay
               buffer; otherwise the ``n_prefs_per_epoch`` candidates whose
               predicted preference probability varies most across the
               ensemble are pushed to the replay buffer.
            3. Every ensemble member is re-initialized and trained for
               ``self.n_epochs`` epochs with binary cross-entropy in
               mini-batches of 32 pairs.
            4. One random member's weights and the replay buffer are saved
               under ``active_models/``.

        Args:
            train_batch1: First batch forwarded to ``add_candidates2replay``.
            train_batch2: Second batch forwarded to ``add_candidates2replay``.
            iteration: Accepted for interface compatibility; not used here.
            debug_mode: Skip adding new candidates.
            use_minibatch: Accepted for interface compatibility; not used here.
            force_n_epochs: Accepted for interface compatibility; not used here.
            use_all_pairs: Train on all candidates instead of selecting by variance.
            skip_ref_trajs: Forwarded to ``add_candidates2replay``.
            only_ref_trajs: Forwarded to ``add_candidates2replay``.
            modify_proxy_reward: Forwarded to ``_calculate_boltzmann_pred_probs``.

        Raises:
            ValueError: If ``unique_id`` is None when saving.
        """
        batch_size = 32

        if not debug_mode:
            self.add_candidates2replay(
                train_batch1, train_batch2, set_num_seqs=use_all_pairs, skip_ref_trajs=skip_ref_trajs, only_ref_trajs=only_ref_trajs
            )

        if use_all_pairs:
            self.replay_buffer = self.candidate_replay_buffer
        else:
            # Filter empty slots up front so that positions in `item_pref_vars`
            # and in `candidates` always refer to the same item.
            candidates = [item for item in self.candidate_replay_buffer.buffer if item is not None]
            # With no candidates there is nothing to score or select.
            if candidates:
                item_pref_vars = []
                with torch.no_grad():
                    for item in candidates:
                        traj1 = self.indice2traj[item["traj1"]]
                        traj2 = self.indice2traj[item["traj2"]]
                        ensemble_pred_probs = []
                        for rm_i in range(self.n_reward_models):
                            pred_probs, _ = self._calculate_boltzmann_pred_probs(
                                traj1, traj2, rm_i, modify_proxy_reward=modify_proxy_reward
                            )
                            ensemble_pred_probs.append(pred_probs.detach().cpu().item())
                        item_pref_vars.append(np.var(ensemble_pred_probs))

                print("mean pref. prob var:", np.mean(item_pref_vars))
                print("min pref. prob var:", np.min(item_pref_vars))
                print("max pref. prob var:", np.max(item_pref_vars))
                print("------")

                # argsort is ascending, so the tail holds the highest-variance pairs.
                top_n_indices = np.argsort(item_pref_vars)[-self.n_prefs_per_epoch:]
                for idx in top_n_indices:
                    item = candidates[idx]
                    self.replay_buffer.push(item["traj1"], item["traj2"], item["true_label"])

        self.re_initialize_ensemble()

        all_losses: List[float] = []
        for _ in range(self.n_epochs):
            buffer_items = [item for item in self.replay_buffer.buffer if item is not None]
            np.random.shuffle(buffer_items)

            # A tensor accumulator keeps `.item()` valid even when the buffer is empty.
            reward_model_loss = torch.zeros(())
            for i in range(0, len(buffer_items), batch_size):
                batch = buffer_items[i : i + batch_size]

                # Each ensemble member takes its own optimizer step on the same mini-batch.
                for rm_i in range(self.n_reward_models):
                    batch_loss = 0.0
                    for item in batch:
                        traj1 = self.indice2traj[item["traj1"]]
                        traj2 = self.indice2traj[item["traj2"]]
                        true_label = item["true_label"]

                        predicted_reward_probs, _ = self._calculate_boltzmann_pred_probs(
                            traj1, traj2, rm_i, modify_proxy_reward=modify_proxy_reward
                        )
                        predicted_reward_probs = predicted_reward_probs.to(self.device)

                        item_loss = F.binary_cross_entropy(predicted_reward_probs, true_label)
                        batch_loss += item_loss

                    loss = batch_loss / max(1, len(batch))
                    optimizer = self.reward_ensemble[rm_i].optimizer
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    reward_model_loss += (batch_loss.detach().cpu() / max(1, len(batch)))

            # Logged value: per-batch mean losses summed over batches, averaged over members.
            reward_model_loss /= max(1, self.n_reward_models)
            all_losses.append(reward_model_loss.item())
            print("reward_model_loss:", reward_model_loss)

        # Saves one random member's weights; raises ValueError if unique_id is None.
        self.save_random_params()

        os.makedirs("active_models", exist_ok=True)
        with open(f"active_models/replay_buffer_{self.unique_id}.pkl", "wb") as f:
            torch.save(self.replay_buffer, f)
            


class RewardWrapper(Wrapper):
    def __init__(self, env, reward_model: str = "custom", unique_id: Optional[str] = None, modify_proxy_reward: bool = False):
        """Wrap ``env`` and load the learned reward network selected by ``reward_model``.

        Args:
            env: Environment to wrap.
            reward_model: One of ``"custom_pandemic_sas"``, ``"custom_tomato"``,
                ``"custom_traffic_sas"`` or ``"custom_glucose_sas"``.
            unique_id: Identifier of the checkpoint to load.
            modify_proxy_reward: Stored on the instance and read by ``step``.

        Raises:
            NotImplementedError: For any other ``reward_model`` value
                (including the default ``"custom"``).
        """
        super().__init__(env)
        self.reward_model = reward_model
        self.modify_proxy_reward = modify_proxy_reward

        # Network hyper-parameters per supported reward model name.
        presets = {
            "custom_pandemic_sas": dict(
                obs_dim=2 * 24 * 13,
                action_dim=3,
                sequence_lens=193,
                discrete_actions=True,
                env_name="pandemic_sas",
                n_epochs=200,
                lr=1e-4,
                n_layers=5,
                layer_size=256,
                use_weight_decay=False,
            ),
            "custom_tomato": dict(
                obs_dim=2 * 36,
                action_dim=4,
                sequence_lens=100,
                discrete_actions=True,
                env_name="tomato",
                n_epochs=200,
                lr=1e-4,
                n_layers=5,
                layer_size=512,
                use_weight_decay=True,
            ),
            "custom_traffic_sas": dict(
                obs_dim=2 * 50,
                action_dim=10,
                sequence_lens=4000,
                discrete_actions=False,
                env_name="traffic_sas",
                n_epochs=50,
                lr=1e-3,
                n_layers=5,
                layer_size=512,
                use_weight_decay=False,
            ),
            "custom_glucose_sas": dict(
                obs_dim=2 * 48 * 2,
                action_dim=1,
                sequence_lens=5760,
                discrete_actions=False,
                env_name="glucose_sas",
                n_epochs=200,
                lr=1e-4,
                n_layers=5,
                layer_size=512,
                use_weight_decay=True,
            ),
        }
        if reward_model not in presets:
            raise NotImplementedError(f"Reward model {reward_model} not implemented")

        # rm_id "random" points at the checkpoint directory active_models/random.
        self.reward_net = RewardModel(unique_id=unique_id, rm_id="random", **presets[reward_model])

        # load_params leaves the module in train mode, so switch to eval afterwards.
        self.reward_net.load_params(map_to_cpu=True)
        self.reward_net.eval()

        if "custom" in self.reward_model:
            # Remember the checkpoint's modification time so step() can detect updates.
            self.timestamp = os.path.getmtime(self.reward_net.get_fp())

    @staticmethod
    def _one_hot_encode(num: int, n_classes: int) -> np.ndarray:
        """Return a length-``n_classes`` zero vector with position ``num`` set to 1."""
        encoding = np.zeros(n_classes)
        encoding[num] = 1
        return encoding

    def reset(self, **kwargs):
        """Reset the wrapped environment and remember the initial observation for the next ``step``."""
        initial_obs, info = self.env.reset(**kwargs)
        self.last_obs = initial_obs
        return initial_obs, info

    def _get_concatenated_obs_action(self, obs, new_obs, actions):
        """Concatenate observation(s) and action features along dim 1 into a float32 network input.

        Discrete actions are one-hot encoded with the reward network's
        ``action_dim`` classes first; ``new_obs`` is left out of the
        concatenation when it is None.
        """
        if self.reward_net.discrete_actions:
            action_features = F.one_hot(actions.long(), self.reward_net.action_dim)
        else:
            action_features = actions

        if new_obs is None:
            pieces = [obs, action_features]
        else:
            pieces = [obs, new_obs, action_features]
        return torch.cat(pieces, dim=1).to(torch.float32)

    def step(self, action):
        """Step the wrapped environment and replace its reward with the learned reward model's output.

        The network input is the previous observation, the new observation and
        the action. When ``modify_proxy_reward`` is set, the returned reward is
        ``tanh(model output) + original reward`` instead of the raw model output.
        The checkpoint file is reloaded whenever its modification time changes.

        Returns:
            ``(obs, reward, terminated, truncated, info)`` where ``info`` gains
            ``"original_reward"`` and, when the learned reward is used,
            ``"modified_reward"``.
        """
        obs, original_reward, terminated, truncated, info = self.env.step(action)

        reward = original_reward

        if "custom" in self.reward_model and self.reward_net is not None:
            # Hot-reload if the on-disk parameters changed. The mtime is read once,
            # before loading, so a write that lands during or after the load still
            # differs from the stored timestamp and triggers a reload next step.
            current_mtime = os.path.getmtime(self.reward_net.get_fp())
            if current_mtime != self.timestamp:
                self.reward_net.load_params(map_to_cpu=True)
                # load_params switches to train mode; restore eval mode as __init__ does.
                self.reward_net.eval()
                self.timestamp = current_mtime
                print("Reloading reward model parameters")

            obs_in = obs
            last_obs_in = self.last_obs
            if isinstance(obs, dict):
                # Dict observations: one-hot the "agent" entry over 26 classes and
                # append the "tomatoes" entry.
                obs_in = np.concatenate((self._one_hot_encode(obs["agent"], 26), obs["tomatoes"]))
                last_obs_in = np.concatenate(
                    (self._one_hot_encode(last_obs_in["agent"], 26), last_obs_in["tomatoes"])
                )

            obs_tensor = torch.from_numpy(np.array(obs_in)).float().unsqueeze(0)
            action_tensor = torch.tensor([action]).float()
            if "pandemic" in self.reward_model and "sas" not in self.reward_model:
                # NOTE(review): this passes None as the first concatenation input,
                # which torch.cat would reject. __init__ currently rejects every
                # reward_model name that reaches this branch; confirm the intended
                # input before enabling such a model.
                last_obs_tensor = None
            else:
                last_obs_tensor = torch.from_numpy(np.array(last_obs_in)).float().unsqueeze(0).flatten(1)

            net_input = self._get_concatenated_obs_action(
                last_obs_tensor, obs_tensor.flatten(1), action_tensor
            ).to(self.reward_net.device)

            with torch.no_grad():
                rm_val = self.reward_net(net_input).squeeze().item()
            reward = rm_val
            if self.modify_proxy_reward:
                reward = np.tanh(reward).item() + original_reward
            info["modified_reward"] = reward

        info["original_reward"] = original_reward
        self.last_obs = obs
        return obs, reward, terminated, truncated, info
