# Adapted from https://raw.githubusercontent.com/wadx2019/Neural-Bandit/refs/heads/main/neuralucb.py
# Original work Copyright (c) 2022 Shutong Ding
# Modifications Copyright (c) 2025 King.com Ltd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

from models.mlp import MLP


class ReplayBuffer:
    """Fixed-capacity ring buffer of (context, reward) pairs.

    Storage is pre-allocated as two NumPy arrays held in ``self.buffer``:
    ``'context'`` with shape ``(capacity, d)`` and ``'reward'`` with shape
    ``(capacity, 1)``. Once ``capacity`` entries have been written, new
    entries overwrite the oldest ones.
    """

    def __init__(self, d, capacity):
        """Allocate zero-filled storage for ``capacity`` contexts of width ``d``."""
        self.capacity = capacity
        self.size = 0      # number of slots that hold real data
        self.pointer = 0   # slot the next ``add`` call writes to
        self.buffer = {
            'context': np.zeros((capacity, d)),
            'reward': np.zeros((capacity, 1)),
        }

    def add(self, context, reward):
        """Write one (context, reward) pair at the write pointer and advance it."""
        slot = self.pointer
        self.buffer['context'][slot] = context
        self.buffer['reward'][slot] = reward
        if self.size < self.capacity:
            self.size += 1
        # Wrap around so the oldest entry is overwritten once the buffer is full.
        self.pointer = (slot + 1) % self.capacity

    def _rows(self, indices):
        """Return the (contexts, rewards) rows selected by ``indices``."""
        return self.buffer['context'][indices], self.buffer['reward'][indices]

    def sample(self, n):
        """Draw ``n`` stored pairs uniformly at random, with replacement."""
        picks = np.random.randint(0, self.size, size=n)
        return self._rows(picks)

    def permute(self):
        """Return every stored pair exactly once, in a random order."""
        order = np.random.permutation(self.size)
        return self._rows(order)


class NeuralBanditUCB:
    """Neural UCB bandit that keeps independent state for every prompt slot.

    For each of the ``segments_per_prompt`` slots the agent owns a reward
    network, an Adam optimizer, a snapshot of the network's initial weights,
    a replay buffer and a square matrix ``sigma_inv`` that is updated with the
    Sherman-Morrison formula from parameter gradients. Candidates are scored
    as ``reward_model(x) + beta * sqrt(g(x)^T sigma_inv g(x))`` where ``g`` is
    the (scaled) gradient of the network output w.r.t. its parameters.
    """

    def __init__(
            self,
            input_dim,
            segments_per_prompt,
            beta=3,
            lamb=1,
            hidden_size=16,
            learning_rate=0.001,
            weight_decay=0.0,
            steps_per_update=100,
            device='cpu',
            **kwargs):
        """Build per-slot models, optimizers, buffers and ``sigma_inv`` matrices.

        Args:
            input_dim: Width of one candidate feature vector.
            segments_per_prompt: Number of slots, i.e. independent bandits.
            beta: Multiplier applied to the exploration bonus.
            lamb: Scale of the identity matrix each ``sigma_inv`` starts from.
            hidden_size: Hidden width passed to the MLP; also used to scale
                gradients in ``grad``.
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
            steps_per_update: Gradient steps per slot in each ``train`` call.
            device: Torch device the models and tensors are moved to.
            **kwargs: Accepted and ignored.
        """
        self.d = input_dim
        self.segments_per_prompt = segments_per_prompt
        self.beta = beta
        self.hidden_size = hidden_size
        self.steps_per_update = steps_per_update
        self.device = device

        self.segment_buffers = []
        self.segment_reward_models = []
        self.initial_weights = []
        self.optimizers = []
        self.sigma_invs = []
        self.numels = []
        for _ in range(segments_per_prompt):
            reward_model = MLP(input_dim=input_dim, hidden_dim=hidden_size, output_dim=1, num_hidden_layers=1).to(device)
            # Snapshot used by ``train`` to restart fitting from the same point.
            initial_weights = {k: v.clone().detach() for k, v in reward_model.state_dict().items()}
            optimizer = optim.Adam(reward_model.parameters(), lr=learning_rate, weight_decay=weight_decay)
            self.segment_reward_models.append(reward_model)
            self.initial_weights.append(initial_weights)
            self.optimizers.append(optimizer)
            self.segment_buffers.append(ReplayBuffer(input_dim, 10000))

            numel = sum(w.numel() for w in reward_model.parameters() if w.requires_grad)
            # NOTE(review): if ``lamb`` is meant as the ridge term of the design
            # matrix, its inverse would be ``I / lamb``; identical for the
            # default ``lamb=1`` -- confirm the intended meaning before changing.
            sigma_inv = lamb * np.eye(numel, dtype=np.float32)
            self.sigma_invs.append(sigma_inv)
            self.numels.append(numel)

        self.loss_hist = []
        self.reward_pred_hist = []
        self.ucb_pred_hist = []
        self.epsilon = []

    def take_action(self, all_segments):
        """Pick, for every slot, the candidate with the highest UCB score.

        Args:
            all_segments: Candidate feature vectors, one row per candidate.

        Returns:
            List with one selected candidate index per slot.

        Side effects:
            Resets ``self.epsilon`` and appends the mean exploration bonus of
            each slot; appends the per-candidate bonus and reward prediction
            to ``self.ucb_pred_hist`` / ``self.reward_pred_hist``.
        """
        # The candidates are the same for every slot, so convert them once.
        # Gradients are taken w.r.t. the model parameters only, hence the
        # input itself does not need ``requires_grad``.
        context = torch.tensor(all_segments, dtype=torch.float32).to(self.device)
        num_segments = len(all_segments)

        selected_segment_idxs = []
        self.epsilon = []
        for seg_idx in range(self.segments_per_prompt):
            sigma_inv = self.sigma_invs[seg_idx]

            # One parameter-gradient row per candidate.
            g = np.zeros((num_segments, self.numels[seg_idx]), dtype=np.float32)
            for k in range(num_segments):
                g[k] = self.grad(context[k].unsqueeze(0), segment_idx=seg_idx).cpu().numpy()

            with torch.no_grad():
                reward_model_term = self.segment_reward_models[seg_idx](context).cpu().numpy()

            # Row-wise quadratic form g_k^T sigma_inv g_k, kept as a column.
            quad_form = np.sum((g @ sigma_inv) * g, axis=1, keepdims=True)
            # Rounding can push a tiny quadratic form below zero; clamp so the
            # square root never yields NaN.
            ucb_term = np.sqrt(np.maximum(quad_form, 0.0))
            p = reward_model_term + self.beta * ucb_term

            self.epsilon.append(ucb_term.mean())
            self.ucb_pred_hist.append(ucb_term)
            self.reward_pred_hist.append(reward_model_term)

            selected_segment_idxs.append(np.argmax(p))

        return selected_segment_idxs

    def grad(self, x, segment_idx):
        """Return the flattened parameter gradient of the slot's model at ``x``.

        The gradient of the model output is divided by ``sqrt(hidden_size)``.
        ``x`` must produce a single-element output because ``backward`` is
        called without an explicit gradient argument.
        """
        model = self.segment_reward_models[segment_idx]
        y = model(x)
        # Clear stale gradients so ``.grad`` holds only this sample's gradient.
        self.optimizers[segment_idx].zero_grad()
        y.backward()
        scale = np.sqrt(self.hidden_size)
        return torch.cat(
            [w.grad.detach().flatten() / scale for w in model.parameters() if w.requires_grad]
        ).to(self.device)

    def store_data(self, context, reward):
        """Record the chosen context of every slot together with the reward.

        For each slot the parameter gradient at the chosen context updates
        ``sigma_inv`` and the (context, reward) pair is added to the slot's
        replay buffer.

        Args:
            context: One feature vector per slot (``segments_per_prompt`` items).
            reward: Reward stored with every slot's context.
        """
        assert len(context) == self.segments_per_prompt
        for seg_idx in range(self.segments_per_prompt):
            # ``reshape(-1)`` flattens like ``squeeze`` did but also keeps a
            # one-element vector 1-D when ``input_dim == 1``.
            to_add = torch.tensor(context[seg_idx], dtype=torch.float32).to(self.device).reshape(-1)
            assert to_add.shape == (self.d,), to_add.shape

            grad = self.grad(to_add.unsqueeze(0), seg_idx)
            assert grad.shape == (self.numels[seg_idx],), grad.shape

            self.sherman_morrison_update(grad[:, None].cpu().numpy(), seg_idx)
            self.segment_buffers[seg_idx].add(to_add.cpu().numpy(), reward)

    def sherman_morrison_update(self, v, segment_idx):
        """Apply the rank-one update for column vector ``v`` to ``sigma_inv`` in place.

        Computes ``S -= (S v)(v^T S) / (1 + v^T S v)``. The two matrix-vector
        products are formed first so no (n, n) @ (n, n) product is needed.
        """
        sigma_inv = self.sigma_invs[segment_idx]
        left = sigma_inv @ v       # shape (n, 1)
        right = v.T @ sigma_inv    # shape (1, n)
        sigma_inv -= (left @ right) / (1 + right @ v)

    def train(self):
        """Refit every slot's reward model on the data stored for that slot.

        Each model is reset to its initial weights and then trained for
        ``steps_per_update`` full-batch MSE steps on the filled part of its
        replay buffer. Slots with fewer than two stored samples are not
        trained.

        Returns:
            1-D array with, per slot, the mean of the last (up to) ten step
            losses, or NaN when no training step was run for that slot.
        """
        mean_recent_losses = []

        for seg_idx in range(self.segments_per_prompt):
            model = self.segment_reward_models[seg_idx]
            optimizer = self.optimizers[seg_idx]
            # NOTE(review): the Adam state is not reset together with the
            # weights -- confirm that carrying it over is intended.
            model.load_state_dict({k: v.clone() for k, v in self.initial_weights[seg_idx].items()})

            replay = self.segment_buffers[seg_idx]
            # Only the first ``size`` rows hold real data; the remainder of
            # the pre-allocated arrays is zero padding and must not be fit.
            num_stored = replay.size

            losses = []
            if num_stored >= 2:
                # Full-batch mean MSE does not depend on row order, so the
                # tensors are built once instead of reshuffled every step.
                x = torch.tensor(replay.buffer['context'][:num_stored], dtype=torch.float32).to(self.device)
                y = torch.tensor(replay.buffer['reward'][:num_stored], dtype=torch.float32).to(self.device).view(-1, 1)

                for _ in range(self.steps_per_update):
                    loss = F.mse_loss(model(x), y)
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    losses.append(loss.item())

            mean_recent_losses.append(np.mean(losses[-10:]) if losses else np.nan)

        return np.array(mean_recent_losses)


class NeuralBanditEpsGreedy:
    """Epsilon-greedy bandit with one reward network per prompt slot.

    Every slot owns a reward model, an SGD optimizer, a snapshot of the
    model's initial weights and a replay buffer. With probability
    ``epsilon`` a slot picks a random candidate, otherwise the candidate with
    the highest predicted reward.
    """

    def __init__(self,
                 input_dim,
                 segments_per_prompt,
                 segment_length,
                 weight_decay=0.0,
                 learning_rate=0.01,
                 steps_per_update=100,
                 epsilon="decay",
                 num_rollouts=100,
                 **kwargs
                 ):
        """Build per-slot models, optimizers and buffers.

        Args:
            input_dim: Width of one candidate feature vector.
            segments_per_prompt: Number of slots, i.e. independent bandits.
            segment_length: Stored on the instance; not used by this class.
            weight_decay: SGD weight decay.
            learning_rate: SGD learning rate.
            steps_per_update: Gradient steps per slot in each ``train`` call.
            epsilon: ``"decay"`` to start at 1.0 and decrease linearly to 0
                over the first 20% of ``num_rollouts`` ``take_action`` calls,
                otherwise a constant exploration probability.
            num_rollouts: Horizon used to derive the decay step.
            **kwargs: Accepted and ignored.
        """
        self.n_segments = segments_per_prompt
        self.segment_length = segment_length
        self.lr = learning_rate
        self.wd = weight_decay

        self.segment_buffers = []
        self.segment_reward_models = []
        self.initial_weights = []
        self.optimizers = []
        self.input_dim = input_dim
        for _ in range(segments_per_prompt):
            reward_model = MLP(input_dim, 1, hidden_dim=16, num_hidden_layers=1)

            # Snapshot used by ``train`` to restart fitting from the same point.
            initial_weights = {k: v.clone().detach() for k, v in reward_model.state_dict().items()}
            optimizer = optim.SGD(reward_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

            self.initial_weights.append(initial_weights)
            self.optimizers.append(optimizer)
            self.segment_reward_models.append(reward_model)
            self.segment_buffers.append(ReplayBuffer(input_dim, 10000))

        self.learning_rate = learning_rate
        self.steps_per_update = steps_per_update

        if epsilon == "decay":
            self.epsilon = 1.0
            self.epsilon_decay = 1.0 / (num_rollouts * 0.2)
        else:
            self.epsilon = float(epsilon)
            self.epsilon_decay = 0.0

        self.loss_fn = nn.MSELoss()

    def store_data(self, segments, reward):
        """Add each slot's chosen segment, paired with ``reward``, to its buffer."""
        assert len(segments) == self.n_segments
        for segment_idx in range(self.n_segments):
            self.segment_buffers[segment_idx].add(segments[segment_idx], reward)

    def train(self):
        """Refit every slot's reward model on the data stored for that slot.

        Each model is reset to its initial weights, switched to train mode and
        trained for ``steps_per_update`` full-batch MSE steps with gradient
        norm clipping at 10. A slot whose buffer is empty is not trained.

        Returns:
            1-D array with the mean step loss per slot, or NaN for a slot that
            ran no training step.
        """
        mean_losses = []
        for segment_idx in range(self.n_segments):
            model = self.segment_reward_models[segment_idx]
            optimizer = self.optimizers[segment_idx]
            replay = self.segment_buffers[segment_idx]

            model.load_state_dict(self.initial_weights[segment_idx])
            model.train()

            losses = []
            # An empty batch only produces a NaN loss; skip the steps and
            # report NaN directly.
            if replay.size > 0:
                for _ in range(self.steps_per_update):
                    contexts, rewards = replay.permute()
                    contexts_tensor = torch.from_numpy(contexts).float()
                    rewards_tensor = torch.from_numpy(rewards).float()

                    optimizer.zero_grad()

                    predictions = model(contexts_tensor)
                    loss = self.loss_fn(predictions.squeeze(), rewards_tensor.squeeze())
                    losses.append(loss.item())

                    loss.backward()

                    torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)
                    optimizer.step()

            mean_losses.append(np.mean(losses) if losses else np.nan)

        return np.array(mean_losses)

    def predict(self, context):
        """Return each slot model's reward predictions for ``context``.

        Models are switched to eval mode first, matching ``take_action``, so
        the result does not depend on whether ``train`` ran last.

        Returns:
            List with one NumPy array of predictions per slot.
        """
        segment_preds = []
        context_tensor = torch.from_numpy(np.array(context)).to(torch.float32)
        with torch.no_grad():
            for segment_idx in range(self.n_segments):
                reward_model = self.segment_reward_models[segment_idx]
                reward_model.eval()
                reward_pred = reward_model(context_tensor).squeeze()
                segment_preds.append(reward_pred.detach().cpu().numpy())

        return segment_preds

    def take_action(self, context):
        """Choose one candidate index per slot, epsilon-greedily.

        Args:
            context: Candidate feature vectors, one row per candidate.

        Returns:
            List with one selected candidate index per slot.

        Side effects:
            Decreases ``self.epsilon`` by ``self.epsilon_decay`` (floored at
            0) once per call.
        """
        selected_segment_idxs = []
        # Converted lazily: only needed if at least one slot acts greedily,
        # and identical for every slot that does.
        context_tensor = None
        with torch.no_grad():
            for s_idx in range(self.n_segments):
                if np.random.rand() < self.epsilon:
                    selected_segment_idx = np.random.randint(len(context))
                else:
                    if context_tensor is None:
                        context_tensor = torch.from_numpy(np.array(context)).to(torch.float32)
                    reward_model = self.segment_reward_models[s_idx]
                    reward_model.eval()

                    segment_reward_preds = reward_model(context_tensor).detach().cpu().numpy()
                    selected_segment_idx = np.argmax(segment_reward_preds)

                selected_segment_idxs.append(selected_segment_idx)

        self.epsilon = max(0.0, self.epsilon - self.epsilon_decay)

        return selected_segment_idxs


class ThompsonSamplingBandit:
    """Linear-Gaussian Thompson sampling with one posterior per prompt slot.

    Each slot ``j`` keeps a Gaussian belief over a weight vector θ with mean
    ``self.mus[j]`` and covariance ``self.Sigmas[j]``. Rewards are modelled as
    ``x · θ`` plus Gaussian noise of known variance ``noise_variance``.
    """

    def __init__(
            self,
            input_dim,
            segments_per_prompt,
            prior_variance=1.0,
            noise_variance=1.0,
            **kwargs
    ):
        """Initialise every slot with a zero-mean isotropic prior.

        Args:
            input_dim: Width of one candidate feature vector.
            segments_per_prompt: Number of slots, i.e. independent posteriors.
            prior_variance: Diagonal value of the prior covariance.
            noise_variance: Known observation noise σ².
            **kwargs: Accepted and ignored.
        """
        self.segments_per_prompt = segments_per_prompt

        self.mus = [np.zeros(input_dim) for _ in range(segments_per_prompt)]  # Prior mean of θ
        self.Sigmas = [np.eye(input_dim) * prior_variance for _ in range(segments_per_prompt)]  # Prior covariance of θ
        self.noise_variance = noise_variance  # Known observation noise σ²

        self.epsilon_hist = []
        self.epsilon = -1

    def take_action(self, all_segments):
        """Sample θ per slot and pick the candidate with the highest ``x · θ``.

        Args:
            all_segments: Candidate feature vectors, one per candidate.

        Returns:
            List with one selected candidate index per slot.
        """
        # One row per candidate so all scores come from a single product.
        candidates = np.asarray(all_segments, dtype=float).reshape(len(all_segments), -1)

        segment_idxs = []
        for j in range(self.segments_per_prompt):
            theta_sample = np.random.multivariate_normal(self.mus[j], self.Sigmas[j])
            segment_idxs.append(np.argmax(candidates @ theta_sample))

        return segment_idxs

    def store_data(self, context, reward):
        """Update every slot's posterior with its (context, reward) observation.

        Uses the rank-one form of the Gaussian posterior update,

            Σ' = Σ - (Σx)(Σx)ᵀ / (σ² + xᵀΣx)
            μ' = μ + Σx (r - xᵀμ) / (σ² + xᵀΣx)

        which equals ``inv(inv(Σ) + x xᵀ / σ²)`` and the matching mean for a
        symmetric Σ, but needs no matrix inversion and keeps Σ' exactly
        symmetric.

        Args:
            context: One feature vector per slot (``segments_per_prompt`` items).
            reward: Observed scalar reward, shared by all slots.
        """
        assert len(context) == self.segments_per_prompt
        for j in range(self.segments_per_prompt):
            x = np.asarray(context[j], dtype=float).reshape(-1)
            Sigma = self.Sigmas[j]
            mu = self.mus[j]

            Sigma_x = Sigma @ x
            denom = self.noise_variance + x @ Sigma_x
            residual = reward - x @ mu

            self.mus[j] = mu + Sigma_x * (residual / denom)
            self.Sigmas[j] = Sigma - np.outer(Sigma_x, Sigma_x) / denom

    def train(self):
        """No-op: the posterior is already updated in ``store_data``."""
        return None



