import os
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from gymnasium import Wrapper
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.policy.sample_batch import SampleBatch

from reward_modeling.replay_buffer import ReplayBuffer


class RewardModel(nn.Module):
    def __init__(
        self,
        obs_dim: int,
        action_dim: int,
        sequence_lens: int,
        discrete_actions: bool,
        env_name: str,
        lr: float = 0.001,
        n_epochs: int = 250,
        n_layers: int = 5,
        layer_size: int = 512,
        use_weight_decay: bool = True,
        unique_id: Optional[str] = None,
    ):
        """Build the reward MLP, its Adam optimizer and the preference storage.

        Args:
            obs_dim: Number of observation features in the network input. The
                input is ``obs``, ``new_obs`` and the action concatenated, so
                this must cover both observations together.
            action_dim: Number of action features (one-hot width when
                ``discrete_actions`` is true).
            sequence_lens: Stored on the instance; not read elsewhere in this
                class.
            discrete_actions: Whether actions are integer indices that get
                one-hot encoded before being fed to the network.
            env_name: Environment name; substrings of it select
                environment-specific handling in other methods.
            lr: Adam learning rate.
            n_epochs: Number of passes over the replay buffer per
                ``update_params`` call.
            n_layers: Number of ``nn.Linear`` layers (the last one has a single
                output).
            layer_size: Width of the hidden layers.
            use_weight_decay: Use an Adam weight decay of ``1e-5`` instead of 0.
            unique_id: Identifier embedded in the checkpoint file name.

        Raises:
            ValueError: If ``unique_id`` is ``None`` or ``n_layers`` is below 1.
        """
        # Fail fast, before allocating the network or touching the device.
        if unique_id is None:
            raise ValueError("unique_id must be set to save/load parameters")
        if n_layers < 1:
            # Zero layers would yield an empty nn.Sequential, i.e. an identity
            # map that returns its input instead of a scalar reward.
            raise ValueError(f"n_layers must be >= 1, got {n_layers}")

        super().__init__()
        self.sequence_lens = sequence_lens
        self.action_dim = action_dim
        self.obs_dim = obs_dim
        self.discrete_actions = discrete_actions
        self.env_name = env_name

        # MLP head: Linear/ReLU pairs followed by a final single-output Linear.
        layers: List[nn.Module] = []
        for i in range(n_layers):
            is_last = i == n_layers - 1
            in_features = obs_dim + action_dim if i == 0 else layer_size
            out_features = 1 if is_last else layer_size
            layers.append(nn.Linear(in_features, out_features))
            if not is_last:
                layers.append(nn.ReLU(inplace=True))
        self.model = nn.Sequential(*layers)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(self.device)
        self.train()
        self.unique_id = unique_id

        # Optimizer
        self.use_weight_decay = use_weight_decay
        self.lr = lr
        wd = 1e-5 if use_weight_decay else 0.0
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=wd)

        self.n_epochs = n_epochs

        # Replay buffer of labelled pairs, plus a cache mapping trajectory keys
        # to the tensors they refer to.
        self.replay_buffer = ReplayBuffer()
        self.indice2traj: Dict[str, Dict[str, torch.Tensor]] = {}
        self.num_batches_added = 0

        print("Create rm with unique_id:", self.unique_id)

    # -------------------- I/O --------------------

    def save_params(self) -> None:
        """Write the model's ``state_dict`` to the path returned by ``get_fp()``.

        The target directory is created when it does not exist. The checkpoint
        is first written to a temporary sibling file and then moved into place
        with ``os.replace``, so a reader that polls the file's modification
        time does not observe a partially written checkpoint.
        """
        print("Saving reward model parameters...")
        fp = self.get_fp()
        os.makedirs(os.path.dirname(fp) or ".", exist_ok=True)
        tmp_fp = f"{fp}.tmp"
        torch.save(self.state_dict(), tmp_fp)
        os.replace(tmp_fp, fp)

    def load_params(self, map_to_cpu: bool = False) -> None:
        """Load the checkpoint at ``get_fp()`` into this module and enter train mode.

        Args:
            map_to_cpu: When true, ``torch.load`` is asked to map every tensor
                to the CPU; otherwise no ``map_location`` is supplied.
        """
        location = torch.device("cpu") if map_to_cpu else None
        checkpoint = torch.load(self.get_fp(), map_location=location)
        self.load_state_dict(checkpoint)
        # Note: callers that want inference mode must call eval() afterwards.
        self.train()

    def get_fp(self) -> str:
        """Return the relative checkpoint path derived from ``unique_id``."""
        filename = f"reward_model_{self.unique_id}.pth"
        return "active_models/" + filename

    # -------------------- Init/Reset --------------------

    def zero_model_params(self) -> None:
        """Xavier-initialise all Linear layers, then zero the last one.

        Every ``nn.Linear`` found via ``self.modules()`` gets Xavier-uniform
        weights and a zero bias. The last Linear in that iteration order then
        has its weights zeroed too, so that layer outputs 0 for any input.
        """
        linears = list(filter(lambda m: isinstance(m, nn.Linear), self.modules()))
        for linear in linears:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

        # The loop above already zeroed this layer's bias; only weights remain.
        output_layer = linears[-1]
        nn.init.zeros_(output_layer.weight)

    def reinitialize_model(self) -> None:
        """Create a fresh Adam optimizer and reset every resettable layer.

        Linear, Conv2d and LayerNorm modules get ``reset_parameters()``;
        BatchNorm1d/2d modules additionally have their running statistics
        reset first.
        """
        decay = 1e-5 if self.use_weight_decay else 0.0
        self.optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=decay)

        submodules = list(self.modules())
        assert len(submodules) > 0, "Model must have layers to reinitialize"

        plain_resettable = (nn.Linear, nn.Conv2d, nn.LayerNorm)
        batch_norms = (nn.BatchNorm1d, nn.BatchNorm2d)
        for module in submodules:
            if isinstance(module, plain_resettable):
                module.reset_parameters()
                continue
            if isinstance(module, batch_norms):
                module.reset_running_stats()
                module.reset_parameters()

    # -------------------- Forward --------------------

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the MLP held in ``self.model`` and return its output."""
        prediction = self.model(x)
        return prediction

    # -------------------- Helpers --------------------

    def _create_sample_batch(
        self, rewards, actions, obs, new_obs, reward_for_pref, proxy_rewards, index: int
    ) -> Dict[str, torch.Tensor]:
        """Bundle the ``index``-th sequence of every field into a dict of tensors.

        Each argument is an indexable collection of per-sequence data; the
        entry at ``index`` is copied into a new tensor.

        Returns:
            A dict keyed by ``SampleBatch.REWARDS``, ``SampleBatch.ACTIONS``,
            ``SampleBatch.OBS``, ``"new_obs"``, ``"reward_for_pref"`` and
            ``"proxy_rewards"``.
        """

        def as_owned_tensor(value) -> torch.Tensor:
            # torch.tensor(existing_tensor) copies but emits a UserWarning on
            # every call; clone().detach() is the documented equivalent.
            if isinstance(value, torch.Tensor):
                return value.clone().detach()
            return torch.tensor(value)

        return {
            SampleBatch.REWARDS: as_owned_tensor(rewards[index]),
            SampleBatch.ACTIONS: as_owned_tensor(actions[index]),
            SampleBatch.OBS: as_owned_tensor(obs[index]),
            "new_obs": as_owned_tensor(new_obs[index]),
            "reward_for_pref": as_owned_tensor(reward_for_pref[index]),
            "proxy_rewards": as_owned_tensor(proxy_rewards[index]),
        }

    def _get_concatenated_obs_action(
        self, obs: torch.Tensor, new_obs: torch.Tensor, actions: torch.Tensor
    ) -> torch.Tensor:
        """Concatenate ``obs``, ``new_obs`` and action features along dim 1.

        Discrete actions are one-hot encoded to ``self.action_dim`` columns;
        continuous actions are used as given. ``obs`` and ``new_obs`` must be
        2-D with one row per step.

        Besides the usual layouts, two equivalent ones are accepted: a column
        of discrete indices shaped ``(T, 1)`` and a flat vector of continuous
        actions shaped ``(T,)``.

        Returns:
            A ``float32`` tensor with one row per step.
        """
        if self.discrete_actions:
            if actions.dim() == 2 and actions.shape[1] == 1:
                # A column of indices would one-hot to (T, 1, A); drop the
                # singleton so the result is (T, A).
                actions = actions.squeeze(1)
            action_features = F.one_hot(actions.long(), self.action_dim)
        else:
            # One scalar action per step: turn (T,) into a (T, 1) column.
            action_features = actions.unsqueeze(1) if actions.dim() == 1 else actions
        net_input = torch.cat([obs, new_obs, action_features], dim=1)
        return net_input.to(torch.float32)

    @staticmethod
    def _calculate_discounted_sum_and_diffs(
        traj1_rews: torch.Tensor, traj2_rews: torch.Tensor, gamma: float = 0.99
    ) -> torch.Tensor:
        discounts1 = gamma ** torch.arange(len(traj1_rews), device=traj1_rews.device)
        discounts2 = gamma ** torch.arange(len(traj2_rews), device=traj2_rews.device)
        # Returns sum(discounted traj2) - sum(discounted traj1)
        return (discounts2 * traj2_rews).sum(dim=0) - (discounts1 * traj1_rews).sum(dim=0)

    @staticmethod
    def _truncate_trajectory(traj: Dict[str, torch.Tensor], max_len: Optional[int]):
        if max_len is None:
            return traj
        for key in traj.keys():
            traj[key] = traj[key][:max_len]
        return traj

    def _calculate_true_reward_comparisons(self, traj1, traj2) -> torch.Tensor:
        """Sample a stochastic preference label from the ``"reward_for_pref"`` returns.

        The probability of label 1 is the sigmoid of (discounted return of
        ``traj2`` minus discounted return of ``traj1``); the label is drawn by
        comparing that probability with a uniform random number.

        Returns:
            A float tensor holding 1.0 or 0.0.
        """
        return_gap = self._calculate_discounted_sum_and_diffs(
            traj1["reward_for_pref"], traj2["reward_for_pref"]
        )
        prefer_second = torch.sigmoid(return_gap)
        noise = torch.rand(prefer_second.size(), device=prefer_second.device)
        return (noise < prefer_second).float()

    def _calculate_pred_rewards(self, traj1, traj2):
        """Return per-step combined rewards (network output + proxy reward) for both trajectories.

        For each trajectory, ``"obs"`` and ``"new_obs"`` are flattened from
        dim 1 on, joined with ``"actions"``, passed through the network, and
        the flattened ``"proxy_rewards"`` are added to the flattened output.
        """
        pair = (traj1, traj2)

        def build_input(traj):
            return self._get_concatenated_obs_action(
                traj["obs"].flatten(1).to(self.device),
                traj["new_obs"].flatten(1).to(self.device),
                traj["actions"].to(self.device),
            )

        inputs = [build_input(traj) for traj in pair]
        head1, head2 = [self.forward(net_input).flatten() for net_input in inputs]
        proxy1, proxy2 = [traj["proxy_rewards"].flatten().to(self.device) for traj in pair]

        # add original proxy reward to the predicted reward
        return head1 + proxy1, head2 + proxy2

    def _calculate_boltzmann_pred_probs(self, traj1, traj2):
        """Predict P(traj2 preferred over traj1) and return the raw network outputs.

        The network is evaluated once per trajectory. Its per-step output (the
        "head") is added to the trajectory's ``"proxy_rewards"``; the
        preference probability is the sigmoid of the difference between the
        two discounted sums of those combined rewards.

        Returns:
            A tuple ``(prob, traj1_head, traj2_head)`` of float tensors, where
            the heads are the network outputs before the proxy reward is
            added (used by the caller for regularization).
        """
        traj1_head = self._head_predictions(traj1)
        traj2_head = self._head_predictions(traj2)

        # add original proxy reward to the predicted reward
        comb1 = traj1_head + traj1["proxy_rewards"].flatten().to(self.device)
        comb2 = traj2_head + traj2["proxy_rewards"].flatten().to(self.device)

        preds_diff = self._calculate_discounted_sum_and_diffs(comb1, comb2)
        softmax_probs = torch.sigmoid(preds_diff)  # P(traj2 preferred over traj1)
        return softmax_probs.float(), traj1_head.float(), traj2_head.float()

    def _head_predictions(self, traj) -> torch.Tensor:
        """Return the flattened network output for every step of ``traj``."""
        net_input = self._get_concatenated_obs_action(
            traj["obs"].flatten(1).to(self.device),
            traj["new_obs"].flatten(1).to(self.device),
            traj["actions"].to(self.device),
        )
        return self.forward(net_input).flatten()

    # -------------------- Sequence utils --------------------

    @staticmethod
    def split_by_sequence(arr, batch_seq_lens: np.ndarray) -> List:
        """Slice a flat sequence into variable-length chunks."""
        sequences = []
        start = 0
        for length in batch_seq_lens:
            end = start + length
            sequences.append(arr[start:end])
            start = end
        if start != len(arr):
            raise ValueError(f"sum(batch_seq_lens)={start} must equal len(arr)={len(arr)}")
        return sequences

    @staticmethod
    def third_and_triple_sequences(sequences: List) -> List:
        """Split sequences into thirds if they are long; helps avoid OOM."""
        tripled: List = []
        longest = max(len(seq) for seq in sequences)
        for seq in sequences:
            if len(seq) >= longest / 3:
                n = len(seq)
                tripled.append(seq[: n // 3])
                tripled.append(seq[n // 3 : (2 * n) // 3])
                tripled.append(seq[(2 * n) // 3 :])
            else:
                tripled.append(seq)
        return tripled

    def get_batch_sequences(self, train_batch, batch_seq_lens: np.ndarray):
        """Split a flat training batch into per-episode sequences, longest first.

        Args:
            train_batch: Mapping with ``SampleBatch.ACTIONS``,
                ``SampleBatch.REWARDS``, ``SampleBatch.OBS``, ``"new_obs"`` and
                ``"infos"`` entries; ``"infos"`` is iterated as one dict per
                step.
            batch_seq_lens: Length of each consecutive episode in the batch.

        Returns:
            A 7-tuple ``(rewards, actions, obs, new_obs, pref_rewards,
            proxy_rewards, seq_lens)``. The first six are lists of
            per-sequence data in matching order; ``seq_lens`` holds the
            corresponding lengths. Sequences are ordered by descending length;
            when ``"glucose"`` is in ``env_name`` they are then further split
            by ``third_and_triple_sequences``.
        """
        infos = train_batch["infos"]
        actions = torch.from_numpy(convert_to_numpy(train_batch[SampleBatch.ACTIONS]))

        # Reward used to label preferences.
        pref_rews_flat = [info.get("true_rew", 0) for info in infos]
        # handle tomato-env mis-naming: fall back to "true_reward" when
        # "true_rew" carries no signal. Only do so if the alternative key is
        # actually present; otherwise a batch whose true rewards legitimately
        # sum to zero would be replaced by all zeros.
        if sum(pref_rews_flat) == 0 and any("true_reward" in info for info in infos):
            pref_rews_flat = [info.get("true_reward", 0) for info in infos]

        proxy_rews_flat = [info.get("original_reward", 0) for info in infos]

        # Flat fields, in the order they are returned.
        flat_fields = (
            train_batch[SampleBatch.REWARDS],
            actions,
            train_batch[SampleBatch.OBS],
            train_batch["new_obs"],
            pref_rews_flat,
            proxy_rews_flat,
        )

        # Sort by sequence length (descending); apply the same order to every field.
        order = sorted(range(len(batch_seq_lens)), key=batch_seq_lens.__getitem__, reverse=True)
        per_field = []
        for flat in flat_fields:
            seqs = self.split_by_sequence(flat, batch_seq_lens)
            per_field.append([seqs[i] for i in order])
        sorted_lens = [batch_seq_lens[i] for i in order]

        if "glucose" in self.env_name:
            # Reduce memory pressure by splitting sequences into thirds
            per_field = [self.third_and_triple_sequences(seqs) for seqs in per_field]
            sorted_lens = [len(seq) for seq in per_field[0]]

        rewards_seqs, acs_seqs, obs_seqs, new_obs_seqs, pref_rew_seqs, proxy_rew_seqs = per_field
        return (
            rewards_seqs,
            acs_seqs,
            obs_seqs,
            new_obs_seqs,
            pref_rew_seqs,
            proxy_rew_seqs,
            sorted_lens,
        )

    @staticmethod
    def get_seq_lens(train_batch) -> np.ndarray:
        """Return the number of steps per episode id, ordered by first appearance.

        Episode ids are read from ``train_batch[SampleBatch.EPS_ID]``.
        """
        episode_ids = train_batch[SampleBatch.EPS_ID]
        distinct_ids, first_seen = np.unique(episode_ids, return_index=True)

        # Slot of each step's episode id within the sorted distinct ids, then
        # the number of steps that landed in each slot.
        id_slots = np.searchsorted(distinct_ids, episode_ids)
        steps_per_id = np.bincount(id_slots)

        # np.unique sorts by id value; restore the order of first appearance.
        appearance_order = np.argsort(first_seen)
        return np.array(steps_per_id[appearance_order])

    # -------------------- Data ingest --------------------

    def add2replay(self, train_batch1, train_batch2) -> None:
        """Pair up sequences from two batches and push labelled pairs to the replay buffer.

        Both batches are split into sequences; a fixed number of them (20 for
        tomato/traffic, 40 for glucose, 80 otherwise, selected by substring of
        ``env_name``) is required from each batch. Every (i, j) combination of
        the selected indices gets a sampled preference label and is pushed as
        a pair of cache keys; the trajectories themselves are stored in
        ``self.indice2traj`` under those keys.
        """
        assert len(train_batch1) == len(
            train_batch2
        ), "Both batches should contain the same number of steps for clean pairing."

        seq_lens_a = self.get_seq_lens(train_batch1)
        seq_lens_b = self.get_seq_lens(train_batch2)

        # fields_* = [rewards, actions, obs, new_obs, pref, proxy]
        *fields_a, seq_lens_a = self.get_batch_sequences(train_batch1, seq_lens_a)
        *fields_b, seq_lens_b = self.get_batch_sequences(train_batch2, seq_lens_b)
        rewards_a, rewards_b = fields_a[0], fields_b[0]

        env = self.env_name
        is_glucose = "glucose" in env
        if "tomato" in env or "traffic" in env:
            required = 20
        elif is_glucose:
            required = 40
        else:
            required = 80
        assert min(len(seq_lens_a), len(seq_lens_b)) >= required

        self.num_batches_added += 1
        batch_no = self.num_batches_added

        # NOTE(review): indices run over range(required - 1), so the sequence at
        # index required - 1 is never paired — confirm this is intended.
        usable = range(required - 1)
        print("# of trajectory pairs to add:", len(usable) * len(usable))

        def lookup_or_build(key, fields, idx, limit):
            # Build (and truncate) the trajectory only on first use of its key.
            if key not in self.indice2traj:
                fresh = self._create_sample_batch(*fields, idx)
                self.indice2traj[key] = self._truncate_trajectory(fresh, limit)
            return self.indice2traj[key]

        for i_idx in usable:
            for j_idx in usable:
                # For glucose, both trajectories are cut to the shorter one's
                # length; the limit is part of the cache key.
                max_len = min(len(rewards_a[i_idx]), len(rewards_b[j_idx])) if is_glucose else None

                t1_key = f"{batch_no}_{i_idx}_t0_max_len={max_len}"
                traj1 = lookup_or_build(t1_key, fields_a, i_idx, max_len)

                t2_key = f"{batch_no}_{j_idx}_t1_max_len={max_len}"
                traj2 = lookup_or_build(t2_key, fields_b, j_idx, max_len)

                label = self._calculate_true_reward_comparisons(traj1, traj2).to(self.device)
                self.replay_buffer.push(t1_key, t2_key, label)

    # -------------------- Training --------------------

    def update_params(
        self,
        train_batch1,
        train_batch2,
        iteration: int,
        debug_mode: bool = False,
        use_minibatch: bool = False,
        push2zero: float = 10.0,
        use_regularization: bool = True,
    ) -> None:
        """Retrain the model from scratch on the replay buffer and save it.

        The model and optimizer are re-initialised, the new batches are added
        to the replay buffer (unless ``debug_mode``), and ``self.n_epochs``
        passes are made over all non-``None`` buffer entries. Finally the
        parameters are written with ``save_params``.

        Args:
            train_batch1: First batch handed to ``add2replay``.
            train_batch2: Second batch handed to ``add2replay``.
            iteration: Accepted for interface compatibility; not used here.
            debug_mode: Skip adding the batches and train on the existing
                buffer only.
            use_minibatch: Take one optimizer step per shuffled mini-batch of
                32 pairs instead of one step per epoch over the whole buffer.
            push2zero: Weight of the regularization term.
            use_regularization: Whether to accumulate the regularization term.
        """
        BATCH_SIZE = 32

        self.reinitialize_model()
        self.train()

        if not debug_mode:
            self.add2replay(train_batch1, train_batch2)

        for _ in range(self.n_epochs):
            items = [item for item in self.replay_buffer.buffer if item is not None]
            if not items:
                # Nothing to learn from: the loss would be a constant without a
                # graph and backward() would raise.
                break

            if use_minibatch:
                np.random.shuffle(items)
                batches = [items[i : i + BATCH_SIZE] for i in range(0, len(items), BATCH_SIZE)]
            else:
                batches = [items]

            for batch in batches:
                loss = self._batch_loss(batch, push2zero, use_regularization)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        # Save model
        self.save_params()

    def _batch_loss(self, batch, push2zero: float, use_regularization: bool) -> torch.Tensor:
        """Return the preference loss plus weighted regularization for ``batch``.

        The preference loss is the mean binary cross-entropy between the
        predicted probability and the stored label over the pairs in
        ``batch``. For regularization, the proxy's preferred trajectory is the
        one with the larger summed ``"proxy_rewards"``: when it matches the
        label both network outputs are pulled towards zero; otherwise only
        trajectory 1's output is (label 0) or only trajectory 2's (label 1).
        """
        pref_loss = torch.tensor(0.0, device=self.device)
        reg_loss = torch.tensor(0.0, device=self.device)
        n_agreements = 0

        for item in batch:
            traj1 = self.indice2traj[item["traj1"]]
            traj2 = self.indice2traj[item["traj2"]]
            true_label = item["true_label"].to(self.device)

            pred_pref_probs, traj1_head, traj2_head = self._calculate_boltzmann_pred_probs(traj1, traj2)
            pred_pref_probs = pred_pref_probs.to(self.device)

            pref_loss = pref_loss + F.binary_cross_entropy(pred_pref_probs, true_label)

            if not use_regularization:
                continue

            det_proxy_label = int(
                np.argmax(
                    [
                        torch.sum(traj1["proxy_rewards"], dim=0).item(),
                        torch.sum(traj2["proxy_rewards"], dim=0).item(),
                    ]
                )
            )

            if bool(true_label == det_proxy_label):
                # regularize towards zero when proxy agrees with true preference
                reg_loss = (
                    reg_loss
                    + F.mse_loss(traj1_head, torch.zeros_like(traj1_head))
                    + F.mse_loss(traj2_head, torch.zeros_like(traj2_head))
                )
                n_agreements += 1
            elif bool(true_label == 0):
                reg_loss = reg_loss + F.mse_loss(traj1_head, torch.zeros_like(traj1_head))
            else:
                reg_loss = reg_loss + F.mse_loss(traj2_head, torch.zeros_like(traj2_head))

        loss = pref_loss / max(1, len(batch))
        # NOTE(review): kept from the original — the whole regularization sum
        # (including disagreement terms) is divided by the agreement count and
        # dropped entirely when there are no agreements. Confirm this is intended.
        if n_agreements > 0:
            loss = loss + push2zero * (reg_loss / n_agreements)
        return loss


class RewardWrapper(Wrapper):
    def __init__(self, env, reward_model: str = "custom", unique_id: Optional[str] = None, modify_proxy_reward: bool = True):
        """Wrap ``env`` and, for a known ``reward_model`` name, load its reward network.

        Recognised names are ``custom_pandemic_sas``, ``custom_tomato``,
        ``custom_traffic_sas`` and ``custom_glucose_sas``; any other value
        leaves ``self.reward_net`` as ``None``. When a network is created, its
        parameters are loaded from disk, it is put in eval mode, and the
        checkpoint's modification time is remembered.

        Raises:
            ValueError: If ``modify_proxy_reward`` is false.
        """
        super().__init__(env)
        self.reward_model = reward_model

        if not modify_proxy_reward:
            raise ValueError("modify_proxy_reward must be True for this wrapper; you are doing something wrong.")

        # Settings shared by most presets; each preset lists only what differs.
        shared = dict(lr=0.0001, n_epochs=200, n_layers=5, layer_size=512, use_weight_decay=True)
        presets = {
            "custom_pandemic_sas": dict(
                obs_dim=2 * 24 * 13,
                action_dim=3,
                sequence_lens=193,
                discrete_actions=True,
                env_name="pandemic_sas",
            ),
            "custom_tomato": dict(
                obs_dim=2 * 36,
                action_dim=4,
                sequence_lens=100,
                discrete_actions=True,
                env_name="tomato",
            ),
            "custom_traffic_sas": dict(
                obs_dim=2 * 50,
                action_dim=10,
                sequence_lens=4000,
                discrete_actions=False,
                env_name="traffic_sas",
                n_epochs=50,
                n_layers=3,
                layer_size=256,
            ),
            "custom_glucose_sas": dict(
                obs_dim=2 * 48 * 2,
                action_dim=1,
                sequence_lens=5760,
                discrete_actions=False,
                env_name="glucose_sas",
                use_weight_decay=False,
            ),
        }

        overrides = presets.get(reward_model)
        if overrides is None:
            self.reward_net = None
            return

        settings = {**shared, **overrides}
        self.reward_net = RewardModel(unique_id=unique_id, **settings)
        self.reward_net.load_params(map_to_cpu=True)
        self.reward_net.eval()
        self.timestamp = os.path.getmtime(self.reward_net.get_fp())

    @staticmethod
    def one_hot_encode(num: int, n_classes: int) -> np.ndarray:
        """Return a length-``n_classes`` float vector that is 1 at ``num`` and 0 elsewhere."""
        encoded = np.zeros(n_classes)
        encoded[num] = 1
        return encoded

    def reset(self, **kwargs):
        """Reset the wrapped env and remember its first observation as ``last_obs``."""
        first_obs, info = self.env.reset(**kwargs)
        # step() pairs this with the next observation to form the network input.
        self.last_obs = first_obs
        return first_obs, info

    def _get_concatenated_obs_action(self, obs, new_obs, actions):
        """Build the reward network's input from ``obs``, ``new_obs`` and ``actions``.

        Delegates to the reward network's own helper of the same name so the
        wrapper and the model cannot drift apart in how they encode actions
        and assemble features. Requires ``self.reward_net`` to be set.
        """
        return self.reward_net._get_concatenated_obs_action(obs, new_obs, actions)

    def step(self, action):
        """Step the wrapped env and add the reward network's output to its reward.

        When a reward network is loaded, the checkpoint on disk is re-read
        whenever its modification time differs from the one last recorded. The
        network is evaluated on (previous observation, new observation,
        action) and its scalar output is added to the env's reward.

        Returns:
            ``(obs, reward, terminated, truncated, info)``. ``info`` always
            gains ``"original_reward"``; ``"modified_reward"`` is added when
            the network was applied.
        """
        obs, original_reward, terminated, truncated, info = self.env.step(action)

        reward = original_reward
        if self.reward_net is not None and "custom" in self.reward_model:
            # hot-reload if the on-disk parameters changed. The mtime is read
            # once, before loading: if the file is rewritten while we load, the
            # recorded value is stale and the next step simply reloads again.
            checkpoint_mtime = os.path.getmtime(self.reward_net.get_fp())
            if checkpoint_mtime != self.timestamp:
                self.reward_net.load_params(map_to_cpu=True)
                # load_params() switches to train mode; restore inference mode.
                self.reward_net.eval()
                self.timestamp = checkpoint_mtime
                print("Reloading reward model parameters")

            obs_in = obs
            last_obs_in = self.last_obs

            # Tomato dict obs handling
            if isinstance(obs, dict):
                # The "agent" entry is one-hot encoded with width 26 (per the
                # original comment, the number of agents in the tomatoes env —
                # TODO confirm).
                obs_in = np.concatenate((self.one_hot_encode(obs["agent"], 26), obs["tomatoes"]))
                last_obs_in = np.concatenate((self.one_hot_encode(last_obs_in["agent"], 26), last_obs_in["tomatoes"]))

            obs_tensor = torch.from_numpy(np.array(obs_in)).float().unsqueeze(0)
            last_obs_tensor = torch.from_numpy(np.array(last_obs_in)).float().unsqueeze(0).flatten(1)
            # Same shape as torch.tensor([action]) — (1,) for a scalar, (1, k)
            # for a length-k array — without building a tensor from a list of
            # ndarrays, which is slow and warns.
            action_tensor = torch.as_tensor(np.asarray(action)).float().unsqueeze(0)

            net_input = self._get_concatenated_obs_action(
                last_obs_tensor, obs_tensor.flatten(1), action_tensor
            ).to(self.reward_net.device)

            with torch.no_grad():
                rm_val = self.reward_net(net_input).squeeze().item()
            reward = rm_val + original_reward
            info["modified_reward"] = reward

        info["original_reward"] = original_reward
        self.last_obs = obs
        return obs, reward, terminated, truncated, info
