# Copyright (C) king.com Ltd 2025
# License: Apache 2.0

import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class RewardModelUCB(nn.Module):
    """Two-hidden-layer ReLU MLP that maps an input vector to a scalar reward.

    Args:
        input_size: Width of the input vector.
        use_custom_init: If True, replace PyTorch's default initialisation
            with the Gaussian scheme in :meth:`custom_initialize`.
    """

    def __init__(self, input_size, use_custom_init=False):
        super().__init__()
        self.hidden_size = 16
        width = self.hidden_size

        # Layers are created in input -> output order; the order also fixes
        # the parameter names ("net.0", "net.2", "net.4") in the state dict.
        stack = [
            nn.Linear(input_size, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 1),
        ]
        self.net = nn.Sequential(*stack)

        if torch.cuda.is_available():
            self.device = torch.device("cuda:0")
        else:
            self.device = torch.device("cpu")
        self.net.to(self.device)

        # UCB custom init
        if use_custom_init:
            self.custom_initialize()

    def custom_initialize(self):
        """Re-draw all weights from zero-mean Gaussians and zero every bias.

        Hidden layers use variance ``4 / hidden_size``; the output layer uses
        variance ``2 / hidden_size``.
        """
        width = self.hidden_size
        hidden_std = (4 / width) ** 0.5
        output_std = (2 / width) ** 0.5

        linears = [module for module in self.net if isinstance(module, nn.Linear)]
        hidden_layers, output_layer = linears[:-1], linears[-1]

        for linear in hidden_layers:
            nn.init.normal_(linear.weight, mean=0.0, std=hidden_std)
            nn.init.zeros_(linear.bias)

        # The output weight is drawn as a flat vector and installed as a
        # (1, hidden_size) row.
        head = torch.randn((width,), device=self.device) * output_std
        output_layer.weight.data = head.unsqueeze(0)
        nn.init.zeros_(output_layer.bias)

    def forward(self, x):
        """Return the predicted reward for ``x``."""
        return self.net(x)


class StandardRewardModel(nn.Module):
    """One-hidden-layer ReLU MLP (16 hidden units) producing a scalar reward.

    Args:
        input_size: Width of the input vector.
    """

    def __init__(self, input_size):
        super().__init__()
        hidden = nn.Linear(input_size, 16)
        head = nn.Linear(16, 1)
        self.net = nn.Sequential(hidden, nn.ReLU(), head)

        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda:0" if use_cuda else "cpu")
        self.net.to(self.device)

    def forward(self, x):
        """Return the predicted reward for ``x``."""
        return self.net(x)


class EpsGreedyBandit(nn.Module):
    """Epsilon-greedy bandit that scores every joint action with one reward network.

    An action is a vector with one entry per segment, each entry taking a
    value in ``1..num_choices_per_segment``. Every combination is enumerated
    once in ``all_flat_actions`` and scored by a ``StandardRewardModel`` that
    is fit to observed rewards with an MSE loss.

    Args:
        epsilon: Probability of choosing a uniformly random action.
        lr: Learning rate of the Adam optimizer.
        update_on_buffer: If True, train on mini-batches sampled from the
            replay buffer; otherwise train on the latest sample only.
        reset_param: If True (buffer mode only), restore the initial weights
            and a fresh optimizer before every training round.
        num_segments: Number of entries in an action vector.
        num_choices_per_segment: Number of values each entry can take.
        updates_per_sample: Gradient steps per ``update`` call in buffer mode.
        buffer_capacity: Size of the circular replay buffer.
        batch_size: Maximum mini-batch size in buffer mode.
    """

    def __init__(self, epsilon=0.1, lr=0.01, update_on_buffer=True, reset_param=False, num_segments=3,
                 num_choices_per_segment=11, updates_per_sample=1, buffer_capacity=1000, batch_size=64):
        super().__init__()

        self.epsilon = epsilon
        self.lr = lr
        self.loss_fn = nn.MSELoss()
        self.update_on_buffer = update_on_buffer
        self.reset_param = reset_param
        self.batch_size = batch_size

        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.updates_per_sample = updates_per_sample

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = StandardRewardModel(input_size=num_segments)
        self.model.to(self.device)
        # state_dict() returns tensors that share storage with the live
        # parameters; clone them so the snapshot keeps the *initial* values
        # even after the optimizer has updated the model in place.
        self.initial_state_dict = {
            key: value.detach().clone() for key, value in self.model.state_dict().items()
        }
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Circular replay buffer of (action, reward) pairs.
        self.buffer_capacity = buffer_capacity
        self.buffer_actions = torch.zeros((self.buffer_capacity, self.num_segments), dtype=torch.float32,
                                          device=self.device)
        self.buffer_rewards = torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device)
        self.buffer_index = 0
        self.buffer_full = False

        # Cartesian product of per-segment choices, one row per joint action.
        base = torch.arange(1, num_choices_per_segment + 1, 1)
        grids = torch.meshgrid(*([base] * num_segments), indexing='ij')
        self.all_flat_actions = torch.stack(grids, dim=-1).float().reshape(-1, num_segments).to(self.device)

        self.name = "Standard eps-greedy bandit"

    def select_action(self):
        """Return a random action with probability ``epsilon``, else the best-scored one."""
        if random.random() < self.epsilon:
            idx = random.randint(0, self.all_flat_actions.shape[0] - 1)
        else:
            # Scoring is inference only; no autograd graph is needed.
            with torch.no_grad():
                preds = self.model(self.all_flat_actions)
            idx = torch.argmax(preds).item()
        return self.all_flat_actions[idx]

    def _fit_step(self, x, y):
        """Take one optimizer step on the MSE between ``model(x)`` and ``y``; return the loss."""
        loss = self.loss_fn(self.model(x), y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def update(self, full_action, reward):
        """Store ``(full_action, reward)`` and train the reward model.

        Args:
            full_action: Tensor with ``num_segments`` entries.
            reward: Scalar reward observed for ``full_action``.

        Returns:
            A one-element list holding the list of loss values from this call.
        """
        self.buffer_actions[self.buffer_index] = full_action
        self.buffer_rewards[self.buffer_index] = reward
        self.buffer_index += 1
        if self.buffer_index >= self.buffer_capacity:
            self.buffer_index = 0
            self.buffer_full = True

        if not self.update_on_buffer:
            x = full_action.reshape(1, -1)
            y = torch.tensor([[reward]], dtype=torch.float32, device=self.device)
            return [[self._fit_step(x, y)]]

        if self.reset_param:
            self.model.load_state_dict(self.initial_state_dict)
            # Adam moments belong to the discarded weights; start them afresh.
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

        max_size = self.buffer_capacity if self.buffer_full else self.buffer_index
        sample_size = min(self.batch_size, max_size)
        losses = []
        for _ in range(self.updates_per_sample):
            idxs = torch.randint(0, max_size, (sample_size,), device=self.device)
            losses.append(self._fit_step(self.buffer_actions[idxs], self.buffer_rewards[idxs]))
        return [losses]


class FactoredEpsGreedyBandit(nn.Module):
    """Epsilon-greedy bandit with an independent reward network per segment.

    Each segment owns a ``StandardRewardModel`` with a scalar input (the
    segment's choice), its own Adam optimizer and its own circular replay
    buffer. Every segment model is trained against the same scalar reward.

    Args:
        epsilon: Exploration probability, or the string ``"linear"`` to use
            ``1 / t`` where ``t`` counts calls to :meth:`select_action`.
        lr: Learning rate of each Adam optimizer.
        update_on_buffer: If True, train on mini-batches sampled from the
            replay buffers; otherwise train on the latest sample only.
        reset_param: If True (buffer mode only), restore each model's initial
            weights and a fresh optimizer before every training round.
        num_segments: Number of entries in an action vector.
        num_choices_per_segment: Number of values each entry can take.
        updates_per_sample: Gradient steps per segment per ``update`` call.
        buffer_capacity: Size of each circular replay buffer.
        batch_size: Maximum mini-batch size in buffer mode.
    """

    def __init__(self, epsilon=0.1, lr=0.01, update_on_buffer=True, reset_param=False, num_segments=3,
                 num_choices_per_segment=11, updates_per_sample=1, buffer_capacity=1000, batch_size=64):

        super().__init__()

        self.epsilon = epsilon
        self.lr = lr
        self.loss_fn = nn.MSELoss()
        self.update_on_buffer = update_on_buffer
        self.reset_param = reset_param
        self.batch_size = batch_size

        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.updates_per_sample = updates_per_sample
        self.counter = 1

        # Per-segment state; index s addresses segment s in every list.
        self.models = []
        self.optimizers = []
        self.buffers = []
        self.rewards = []
        self.indices = []
        self.filled_flags = []
        self.initial_dicts = []
        self.all_seen_actions = []
        self.all_seen_rewards = []

        self.buffer_capacity = buffer_capacity
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        for _ in range(num_segments):
            model = StandardRewardModel(input_size=1)
            self.models.append(model)
            self.optimizers.append(optim.Adam(model.parameters(), lr=lr))
            # state_dict() returns tensors that share storage with the live
            # parameters; clone them so the snapshot keeps the initial values.
            self.initial_dicts.append(
                {key: value.detach().clone() for key, value in model.state_dict().items()}
            )

            self.buffers.append(torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device))
            self.rewards.append(torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device))
            self.indices.append(0)
            self.filled_flags.append(False)

        self.all_segment_actions = torch.arange(1, num_choices_per_segment + 1, 1).float().to(self.device)
        self.name = "Factored eps-greedy bandit"

    def select_action(self):
        """Choose one value per segment, each by its own epsilon-greedy draw.

        Returns:
            Tensor with ``num_segments`` entries.
        """
        epsilon = 1 / self.counter if self.epsilon == "linear" else float(self.epsilon)
        candidates = self.all_segment_actions.view(-1, 1)

        full_action = []
        for s in range(self.num_segments):
            if random.random() < epsilon:
                idx = random.randint(0, self.all_segment_actions.shape[0] - 1)
            else:
                # Scoring is inference only; no autograd graph is needed.
                with torch.no_grad():
                    preds = self.models[s](candidates)
                idx = torch.argmax(preds).item()
            full_action.append(self.all_segment_actions[idx])

        self.counter += 1
        return torch.stack(full_action).to(self.device)

    def _fit_step(self, s, x, y):
        """Take one optimizer step for segment ``s`` on the MSE of ``(x, y)``; return the loss."""
        loss = self.loss_fn(self.models[s](x), y)
        self.optimizers[s].zero_grad()
        loss.backward()
        self.optimizers[s].step()
        return loss.item()

    def update(self, full_action, reward):
        """Record ``(full_action, reward)`` and train every segment model.

        Args:
            full_action: Tensor with ``num_segments`` entries.
            reward: Scalar reward observed for ``full_action``.

        Returns:
            A list with one list of loss values per segment.
        """
        self.all_seen_actions.append(full_action.detach().cpu().numpy().reshape(self.num_segments))
        self.all_seen_rewards.append(reward)

        all_losses = [[] for _ in range(self.num_segments)]
        for s in range(self.num_segments):
            slot = self.indices[s]
            self.buffers[s][slot] = full_action[s]
            self.rewards[s][slot] = reward
            self.indices[s] += 1
            if self.indices[s] >= self.buffer_capacity:
                self.indices[s] = 0
                self.filled_flags[s] = True

            if not self.update_on_buffer:
                x = full_action[s].reshape(1, 1)
                y = torch.tensor([[reward]], dtype=torch.float32, device=self.device)
                all_losses[s].append(self._fit_step(s, x, y))
                continue

            if self.reset_param:
                self.models[s].load_state_dict(self.initial_dicts[s])
                # Adam moments belong to the discarded weights; start them afresh.
                self.optimizers[s] = optim.Adam(self.models[s].parameters(), lr=self.lr)

            size = self.buffer_capacity if self.filled_flags[s] else self.indices[s]
            sample_size = min(self.batch_size, size)
            for _ in range(self.updates_per_sample):
                idxs = torch.randint(0, size, (sample_size,), device=self.device)
                all_losses[s].append(self._fit_step(s, self.buffers[s][idxs], self.rewards[s][idxs]))

        return all_losses


class FactoredUCBBandit(nn.Module):
    """UCB bandit with an independent reward network per segment.

    Each segment owns a ``RewardModelUCB`` fed with the one-hot encoding of
    the segment's choice. An action's score is the predicted reward plus an
    exploration bonus computed from the gradient of the network output with
    respect to its parameters and a per-segment matrix ``sigma_inv`` that is
    updated with every observed action.

    Args:
        epsilon: Stored on the instance; not used by this class.
        lr: Learning rate of each Adam optimizer.
        update_on_buffer: If True, train on mini-batches sampled from the
            replay buffers; otherwise train on the latest sample only.
        reset_param: If True (buffer mode only), restore each model's initial
            weights and a fresh optimizer before every training round.
        num_segments: Number of entries in an action vector.
        num_choices_per_segment: Number of values each entry can take.
        updates_per_sample: Gradient steps per segment per ``update`` call.
        buffer_capacity: Size of each circular replay buffer.
        batch_size: Maximum mini-batch size in buffer mode.
        beta: Scale of the exploration bonus.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, epsilon=0.1, lr=0.01, update_on_buffer=True, reset_param=False, num_segments=3,
                 num_choices_per_segment=11, updates_per_sample=1, buffer_capacity=1000, batch_size=64, beta=1.0,
                 **kwargs):
        super().__init__()
        self.context_dim = num_choices_per_segment

        self.epsilon = epsilon
        self.loss_fn = nn.MSELoss()
        self.update_on_buffer = update_on_buffer
        self.reset_param = reset_param
        self.batch_size = batch_size
        self.lr = lr

        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.updates_per_sample = updates_per_sample
        self.buffer_capacity = buffer_capacity
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.Ts = []
        lamb = 1
        self.all_segment_actions = torch.arange(1, num_choices_per_segment + 1, 1).float().to(self.device)
        self.K = self.all_segment_actions.shape[0]
        self.beta = beta
        self.ucb_pred_hist = []
        self.reward_pred_hist = []
        self.expl_coef_hist = []

        # Per-segment state; index s addresses segment s in every list.
        self.models = []
        self.optimizers = []
        self.buffers = []
        self.rewards = []
        self.indices = []
        self.filled_flags = []
        self.initial_dicts = []
        self.all_seen_actions = []
        self.all_seen_rewards = []
        self.all_sigma_invs = []
        self.all_theta0s = []

        self.sigma_reset_every = 10

        for _ in range(num_segments):
            model = RewardModelUCB(input_size=self.context_dim, use_custom_init=True).to(self.device)
            self.numel = sum(w.numel() for w in model.parameters() if w.requires_grad)
            self.models.append(model)
            self.optimizers.append(optim.Adam(model.parameters(), lr=lr))
            # state_dict() returns tensors that share storage with the live
            # parameters; clone them so the snapshot keeps the initial values.
            self.initial_dicts.append(
                {key: value.detach().clone() for key, value in model.state_dict().items()}
            )

            self.buffers.append(torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device))
            self.rewards.append(torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device))
            self.indices.append(0)
            self.filled_flags.append(False)

            self.all_sigma_invs.append(lamb * np.eye(self.numel, dtype=np.float32))
            # Detached copy of the initial parameters. If it stayed attached
            # to the autograd graph, the regulariser in update() would also
            # back-propagate through theta0 and its gradient would cancel.
            theta0 = torch.cat(
                [w.detach().flatten() for w in model.parameters() if w.requires_grad]
            )
            self.all_theta0s.append(theta0)
            self.Ts.append(0)
            self.reward_pred_hist.append([])
            self.ucb_pred_hist.append([])

        self.name = "Factored UCB bandit"

    def action_to_one_hot(self, segment_action):
        """Return a NumPy one-hot vector for a single 1-based segment choice."""
        one_hot = np.zeros(self.num_choices_per_segment)
        one_hot[int(segment_action - 1)] = 1

        return one_hot

    def batch_action_to_one_hot(self, segment_actions):
        """One-hot encode a batch of 1-based segment choices.

        Args:
            segment_actions: Tensor whose first dimension is the batch and
                which holds one choice per row (shape ``(B,)`` or ``(B, 1)``).

        Returns:
            Tensor of shape ``(B, num_choices_per_segment)``.
        """
        batch_size = segment_actions.size(0)
        num_choices = self.num_choices_per_segment

        columns = (segment_actions.long() - 1).reshape(-1)
        rows = torch.arange(batch_size, device=segment_actions.device)

        one_hot = torch.zeros(batch_size, num_choices, device=segment_actions.device)
        one_hot[rows, columns] = 1.0

        assert one_hot.shape == (batch_size, num_choices), one_hot.shape
        assert torch.all(one_hot.sum(dim=1) == 1), one_hot.sum(dim=1)

        return one_hot

    def sherman_morrison_update(self, v, s_idx):
        """Apply a damped rank-one update to segment ``s_idx``'s ``sigma_inv`` in place.

        Args:
            v: Column vector of shape ``(numel, 1)``.
            s_idx: Segment index.
        """
        sigma_inv = self.all_sigma_invs[s_idx]
        # Matrix-vector products only: O(n^2) instead of multiplying two
        # n x n matrices.
        left = sigma_inv @ v        # (n, 1)
        right = v.T @ sigma_inv     # (1, n)
        # NOTE(review): the 0.5 factor is not part of the Sherman-Morrison
        # identity, so sigma_inv is not the exact inverse of a rank-one
        # updated matrix. It looks like deliberate damping - confirm.
        sigma_inv -= ((left @ right) / (1 + right @ v)) * 0.5

    def grad(self, x, s_idx):
        """Return the flattened gradient of segment ``s_idx``'s model output at ``x``.

        ``x`` must yield a single-element output so that ``backward()`` needs
        no explicit gradient argument.
        """
        model = self.models[s_idx]
        y = model(x)
        self.optimizers[s_idx].zero_grad()
        y.backward()
        g = torch.cat(
            [w.grad.detach().flatten() for w in model.parameters() if w.requires_grad]
        ).to(self.device)

        return g

    def select_action(self):
        """Pick, per segment, the choice with the highest reward-plus-bonus score.

        Returns:
            Tensor with ``num_segments`` entries.
        """
        # The candidate encodings are identical for every segment.
        features = self.batch_action_to_one_hot(self.all_segment_actions).view(-1, self.context_dim)

        full_action = []
        for s_idx in range(self.num_segments):
            model = self.models[s_idx]

            # One backward pass per candidate: row k is d output / d params.
            g = np.zeros((self.K, self.numel), dtype=np.float32)
            for k in range(self.K):
                g[k] = self.grad(features[k:k + 1], s_idx=s_idx).cpu().numpy()

            with torch.no_grad():
                reward_model_term = model(features).cpu().numpy()

            # Quadratic form g_k^T sigma_inv g_k / hidden_size for every k.
            quad = np.sum((g @ self.all_sigma_invs[s_idx]) * g, axis=1, keepdims=True) / model.hidden_size
            # Round-off can push a near-zero value slightly negative; clamp
            # so the square root never yields NaN.
            ucb_term = np.sqrt(np.maximum(quad, 0.0))

            expl_coef = self.beta * np.sqrt(np.log(self.Ts[s_idx] + 1))
            if s_idx == 0:
                self.expl_coef_hist.append(expl_coef)

            p = reward_model_term + expl_coef * ucb_term

            self.ucb_pred_hist[s_idx].append(ucb_term)
            self.reward_pred_hist[s_idx].append(reward_model_term)

            a_idx = np.argmax(p)
            full_action.append(self.all_segment_actions[a_idx])

        return torch.stack(full_action).to(self.device)

    def update(self, full_action, reward):
        """Record ``(full_action, reward)``, update ``sigma_inv`` and train every segment model.

        Args:
            full_action: Tensor with ``num_segments`` entries.
            reward: Scalar reward observed for ``full_action``.

        Returns:
            A list with one list of loss values per segment.
        """
        self.all_seen_actions.append(full_action.detach().cpu().numpy().reshape(self.num_segments))
        self.all_seen_rewards.append(reward)

        all_losses = [[] for _ in range(self.num_segments)]
        for s in range(self.num_segments):
            slot = self.indices[s]
            self.buffers[s][slot] = full_action[s]
            self.rewards[s][slot] = reward
            self.indices[s] += 1
            if self.indices[s] >= self.buffer_capacity:
                self.indices[s] = 0
                self.filled_flags[s] = True

            a_onehot = torch.from_numpy(self.action_to_one_hot(full_action[s])).float().to(self.device)
            grad = self.grad(a_onehot.unsqueeze(0), s_idx=s)
            assert grad.shape == (self.numel,), grad.shape
            grad_unsqueeze = grad[:, None]
            assert grad_unsqueeze.shape == (self.numel, 1), grad_unsqueeze

            self.sherman_morrison_update(grad_unsqueeze.cpu().numpy(), s_idx=s)
            self.Ts[s] += 1

            model = self.models[s]
            if self.update_on_buffer:
                if self.reset_param:
                    model.load_state_dict(self.initial_dicts[s])
                    self.optimizers[s] = optim.Adam(model.parameters(), lr=self.lr)

                size = self.buffer_capacity if self.filled_flags[s] else self.indices[s]
                sample_size = min(self.batch_size, size)
                for _ in range(self.updates_per_sample):
                    idxs = torch.randint(0, size, (sample_size,), device=self.device)
                    x = self.batch_action_to_one_hot(self.buffers[s][idxs])
                    y = self.rewards[s][idxs]
                    loss = self.loss_fn(model(x), y)

                    # Squared distance to the initial parameters.
                    theta = torch.cat([w.flatten() for w in model.parameters() if w.requires_grad])
                    delta = theta - self.all_theta0s[s]
                    loss = loss + 1e-4 * (delta * delta).sum()

                    self.optimizers[s].zero_grad()
                    loss.backward()
                    self.optimizers[s].step()
                    all_losses[s].append(loss.item())
            else:
                # The network input is the one-hot encoding, not the raw choice.
                x = a_onehot.unsqueeze(0)
                y = torch.tensor([[reward]], dtype=torch.float32, device=self.device)
                loss = self.loss_fn(model(x), y)
                self.optimizers[s].zero_grad()
                loss.backward()
                self.optimizers[s].step()
                all_losses[s].append(loss.item())

        return all_losses


class UCBBandit(nn.Module):
    """UCB bandit that scores every joint action with one reward network.

    A joint action is encoded as the concatenation of one one-hot block per
    segment and fed to a ``RewardModelUCB``. An action's score is the
    predicted reward plus an exploration bonus computed from the gradient of
    the network output with respect to its parameters and the matrix
    ``sigma_inv``, which is updated with every observed action.

    Args:
        epsilon: Stored on the instance; not used by this class.
        lr: Learning rate of the Adam optimizer.
        update_on_buffer: If True, train on mini-batches sampled from the
            replay buffer; otherwise train on the latest sample only.
        reset_param: If True (buffer mode only), restore the initial weights
            and a fresh optimizer before every training round.
        num_segments: Number of entries in an action vector.
        num_choices_per_segment: Number of values each entry can take.
        updates_per_sample: Gradient steps per ``update`` call in buffer mode.
        buffer_capacity: Size of the circular replay buffer.
        batch_size: Maximum mini-batch size in buffer mode.
        beta: Scale of the exploration bonus.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, epsilon=0.1, lr=0.01, update_on_buffer=True, reset_param=False, num_segments=3,
                 num_choices_per_segment=11, updates_per_sample=1, buffer_capacity=1000, batch_size=64, beta=1.0,
                 **kwargs):
        super().__init__()

        self.epsilon = epsilon
        self.loss_fn = nn.MSELoss()
        self.update_on_buffer = update_on_buffer
        self.reset_param = reset_param
        self.batch_size = batch_size

        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.updates_per_sample = updates_per_sample

        self.context_dim = num_segments * num_choices_per_segment

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = RewardModelUCB(input_size=self.context_dim, use_custom_init=True)
        self.model.to(self.device)
        # state_dict() returns tensors that share storage with the live
        # parameters; clone them so the snapshot keeps the initial values.
        self.initial_state_dict = {
            key: value.detach().clone() for key, value in self.model.state_dict().items()
        }
        self.lr = lr
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Circular replay buffer of (action, reward) pairs.
        self.buffer_capacity = buffer_capacity
        self.buffer_actions = torch.zeros((self.buffer_capacity, self.num_segments), dtype=torch.float32,
                                          device=self.device)
        self.buffer_rewards = torch.zeros((self.buffer_capacity, 1), dtype=torch.float32, device=self.device)
        self.buffer_index = 0
        self.buffer_full = False

        # Cartesian product of per-segment choices, one row per joint action.
        base = torch.arange(1, num_choices_per_segment + 1, 1)
        grids = torch.meshgrid(*([base] * num_segments), indexing='ij')
        self.all_flat_actions = torch.stack(grids, dim=-1).float().reshape(-1, num_segments).to(self.device)

        self.name = "Standard UCB bandit"

        self.T = 0
        lamb = 1
        self.numel = sum(w.numel() for w in self.model.parameters() if w.requires_grad)
        self.sigma_inv = lamb * np.eye(self.numel, dtype=np.float32)
        # Detached copy of the initial parameters. If it stayed attached to
        # the autograd graph, the regulariser in update() would also
        # back-propagate through theta0 and its gradient would cancel.
        self.theta0 = torch.cat(
            [w.detach().flatten() for w in self.model.parameters() if w.requires_grad]
        )
        self.K = self.all_flat_actions.shape[0]
        self.beta = beta

        self.ucb_pred_hist = []
        self.reward_pred_hist = []
        self.expl_coef_hist = []

        self.all_seen_actions = []
        self.all_seen_rewards = []

    def action_to_one_hot(self, full_action):
        """Return the NumPy encoding of one action: a one-hot block per segment, concatenated."""
        one_hot = np.zeros(self.num_segments * self.num_choices_per_segment)
        for seg_idx in range(self.num_segments):
            choice_idx = int(full_action[seg_idx]) - 1
            pos = seg_idx * self.num_choices_per_segment + choice_idx
            one_hot[pos] = 1
        return one_hot

    def batch_action_to_one_hot(self, full_actions):
        """
        full_actions: tensor of shape (batch_size, num_segments)
        Returns:
            one_hot: tensor of shape (batch_size, num_segments * num_choices_per_segment)
        """
        batch_size = full_actions.size(0)
        num_segments = self.num_segments
        num_choices = self.num_choices_per_segment
        device = full_actions.device

        # Column of each active entry: segment offset + 0-based choice.
        segment_offsets = torch.arange(num_segments, device=device) * num_choices
        columns = ((full_actions.long() - 1) + segment_offsets).reshape(-1)
        rows = torch.arange(batch_size, device=device).repeat_interleave(num_segments)

        one_hot = torch.zeros(batch_size, num_segments * num_choices, device=device)
        one_hot[rows, columns] = 1.0

        return one_hot

    def sherman_morrison_update(self, v):
        """Apply a damped rank-one update to ``sigma_inv`` in place.

        Args:
            v: Column vector of shape ``(numel, 1)``.
        """
        # Matrix-vector products only: O(n^2) instead of multiplying two
        # n x n matrices.
        left = self.sigma_inv @ v        # (n, 1)
        right = v.T @ self.sigma_inv     # (1, n)
        # NOTE(review): the 0.5 factor is not part of the Sherman-Morrison
        # identity, so sigma_inv is not the exact inverse of a rank-one
        # updated matrix. It looks like deliberate damping - confirm.
        self.sigma_inv -= ((left @ right) / (1 + right @ v)) * 0.5

    def grad(self, x):
        """Return the flattened gradient of the model output at ``x``.

        ``x`` must yield a single-element output so that ``backward()`` needs
        no explicit gradient argument.
        """
        y = self.model(x)
        self.optimizer.zero_grad()
        y.backward()
        return torch.cat(
            [w.grad.detach().flatten() for w in self.model.parameters() if w.requires_grad]
        ).to(self.device)

    def select_action(self):
        """Return the joint action with the highest reward-plus-bonus score."""
        features = self.batch_action_to_one_hot(self.all_flat_actions)

        # One backward pass per candidate: row k is d output / d params.
        g = np.zeros((self.K, self.numel), dtype=np.float32)
        for k in range(self.K):
            g[k] = self.grad(features[k:k + 1]).cpu().numpy()

        with torch.no_grad():
            reward_model_term = self.model(features).detach().cpu().numpy()

        # Quadratic form g_k^T sigma_inv g_k / hidden_size for every k.
        quad = np.sum((g @ self.sigma_inv) * g, axis=1, keepdims=True) / self.model.hidden_size
        # Round-off can push a near-zero value slightly negative; clamp so
        # the square root never yields NaN.
        ucb_term = np.sqrt(np.maximum(quad, 0.0))

        expl_coef = self.beta * np.sqrt(np.log(self.T + 1))
        self.expl_coef_hist.append(expl_coef)

        p = reward_model_term + expl_coef * ucb_term

        self.ucb_pred_hist.append(ucb_term)
        self.reward_pred_hist.append(reward_model_term)

        idx = np.argmax(p)
        return self.all_flat_actions[idx]

    def update(self, full_action, reward):
        """Record ``(full_action, reward)``, update ``sigma_inv`` and train the model.

        Args:
            full_action: Tensor with ``num_segments`` entries.
            reward: Scalar reward observed for ``full_action``.

        Returns:
            A one-element list holding the list of loss values from this call.
        """
        self.all_seen_actions.append(full_action.detach().cpu().numpy().reshape(self.num_segments))
        self.all_seen_rewards.append(reward)

        self.buffer_actions[self.buffer_index] = full_action
        self.buffer_rewards[self.buffer_index] = reward
        self.buffer_index += 1
        if self.buffer_index >= self.buffer_capacity:
            self.buffer_index = 0
            self.buffer_full = True

        a_onehot = torch.from_numpy(self.action_to_one_hot(full_action)).float().to(self.device)
        grad = self.grad(a_onehot.unsqueeze(0))
        assert grad.shape == (self.numel,), grad.shape
        grad_unsqueeze = grad[:, None]
        assert grad_unsqueeze.shape == (self.numel, 1), grad_unsqueeze

        self.sherman_morrison_update(grad_unsqueeze.cpu().numpy())
        self.T += 1

        all_losses = [[]]
        if self.update_on_buffer:
            if self.reset_param:
                self.model.load_state_dict(self.initial_state_dict)
                self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)

            max_size = self.buffer_capacity if self.buffer_full else self.buffer_index
            sample_size = min(self.batch_size, max_size)
            for _ in range(self.updates_per_sample):
                idxs = torch.randint(0, max_size, (sample_size,), device=self.device)
                x = self.batch_action_to_one_hot(self.buffer_actions[idxs])
                y = self.buffer_rewards[idxs]

                loss = self.loss_fn(self.model(x), y)
                # Squared distance to the initial parameters.
                theta = torch.cat([w.flatten() for w in self.model.parameters() if w.requires_grad])
                delta = theta - self.theta0
                loss = loss + 1e-4 * (delta * delta).sum()

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                all_losses[0].append(loss.item())
        else:
            # The network input is the one-hot encoding, not the raw action.
            x = a_onehot.unsqueeze(0)
            y = torch.tensor([[reward]], dtype=torch.float32, device=self.device)
            loss = self.loss_fn(self.model(x), y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            all_losses[0].append(loss.item())

        return all_losses


class TSBandit:
    """Linear-Gaussian Thompson-sampling bandit over the joint action space.

    A joint action is encoded as the concatenation of one one-hot block per
    segment. The reward is modelled as ``theta . x`` with a Gaussian
    posterior ``N(mu, Sigma)`` over ``theta``.

    Args:
        context_dim: Ignored; the dimension is always recomputed as
            ``num_segments * num_choices_per_segment``. Kept so existing
            call sites keep working.
        num_choices_per_segment: Number of values each action entry can take.
        num_segments: Number of entries in an action vector.
        prior_variance: Variance of the isotropic Gaussian prior on ``theta``.
        noise_variance: Variance assumed for the reward noise.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, context_dim, num_choices_per_segment, num_segments, prior_variance=10.0, noise_variance=1.1, **kwargs):
        context_dim = num_segments * num_choices_per_segment

        self.mu = np.zeros(context_dim)
        self.Sigma = np.eye(context_dim) * prior_variance
        self.noise_variance = noise_variance

        self.context_dim = context_dim
        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # Cartesian product of per-segment choices, one row per joint action.
        base = torch.arange(1, num_choices_per_segment + 1, 1)
        grids = torch.meshgrid(*([base] * num_segments), indexing='ij')
        self.all_flat_actions = torch.stack(grids, dim=-1).float().reshape(-1, num_segments).to(self.device)
        self.all_flat_actions_np = self.all_flat_actions.cpu().numpy()

        # The encodings never change, so build them once instead of on every
        # select_action() call. Row k encodes all_flat_actions[k].
        self._action_features = np.stack(
            [self.action_to_one_hot(action) for action in self.all_flat_actions_np]
        )

        self.name = "Standard Thompson Sampling bandit"
        self.all_seen_actions = []
        self.all_seen_rewards = []

    def action_to_one_hot(self, full_action):
        """Return the NumPy encoding of one action: a one-hot block per segment, concatenated."""
        one_hot = np.zeros(self.num_segments * self.num_choices_per_segment)
        for seg_idx in range(self.num_segments):
            choice_idx = int(full_action[seg_idx]) - 1
            pos = seg_idx * self.num_choices_per_segment + choice_idx
            one_hot[pos] = 1
        return one_hot

    def select_action(self):
        """Sample θ from the posterior and return the action with the highest sampled reward."""
        theta_sample = np.random.multivariate_normal(self.mu, self.Sigma)
        predicted_mean = self._action_features @ theta_sample

        idx = int(np.argmax(predicted_mean))
        return self.all_flat_actions[idx]

    def update(self, full_action, reward):
        """Update the posterior given one ``(action, reward)`` observation.

        Uses the rank-one form of the Gaussian posterior update, which is
        algebraically identical to inverting ``Sigma^-1 + x x^T / noise`` but
        needs no matrix inversion:

            Sigma' = Sigma - (Sigma x)(Sigma x)^T / (noise + x^T Sigma x)
            mu'    = mu + Sigma x (reward - x^T mu) / (noise + x^T Sigma x)
        """
        action_np = full_action.detach().cpu().numpy()
        self.all_seen_actions.append(action_np.reshape(self.num_segments))
        self.all_seen_rewards.append(reward)

        x = self.action_to_one_hot(action_np)

        sigma_x = self.Sigma @ x
        denom = self.noise_variance + x @ sigma_x
        residual = reward - x @ self.mu

        # Both right-hand sides use the prior Sigma via sigma_x.
        self.mu = self.mu + sigma_x * (residual / denom)
        self.Sigma = self.Sigma - np.outer(sigma_x, sigma_x) / denom


class FactoredTSBandit:
    """Linear-Gaussian Thompson-sampling bandit with an independent posterior per segment.

    Each segment keeps a Gaussian posterior ``N(mu, Sigma)`` over one weight
    per choice (one-hot features). Every segment is updated with the same
    scalar reward.

    Args:
        num_choices_per_segment: Number of values each action entry can take.
        num_segments: Number of entries in an action vector.
        context_dim: Ignored; the dimension is always
            ``num_choices_per_segment``. Kept so existing call sites keep
            working.
        prior_variance: Variance of the isotropic Gaussian prior.
        noise_variance: Variance assumed for the reward noise.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_choices_per_segment, num_segments, context_dim=1, prior_variance=10.0, noise_variance=0.1, **kwargs):
        context_dim = num_choices_per_segment

        # Per-segment posterior parameters; index s addresses segment s.
        self.mus = [np.zeros(context_dim) for _ in range(num_segments)]
        self.sigmas = [np.eye(context_dim) * prior_variance for _ in range(num_segments)]
        self.noise_variances = [noise_variance for _ in range(num_segments)]

        self.context_dim = context_dim
        self.num_segments = num_segments
        self.num_choices_per_segment = num_choices_per_segment
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.all_segment_actions = torch.arange(1, num_choices_per_segment + 1, 1).float().to(self.device)
        self.all_segment_actions_np = self.all_segment_actions.cpu().numpy()

        self.name = "Factored Thompson Sampling Bandit"
        self.all_seen_actions = []
        self.all_seen_rewards = []

    def action_to_one_hot(self, segment_action):
        """Return a NumPy one-hot vector for a single 1-based segment choice."""
        one_hot = np.zeros(self.num_choices_per_segment)
        one_hot[int(segment_action - 1)] = 1

        return one_hot

    def select_action(self):
        """Sample θ per segment and return the choices with the highest sampled reward.

        Returns:
            Tensor with ``num_segments`` entries.
        """
        full_action = []
        for s_idx in range(self.num_segments):
            theta_sample = np.random.multivariate_normal(self.mus[s_idx], self.sigmas[s_idx])
            # Choice k (value k + 1) has the one-hot feature e_k, so its
            # sampled reward theta . e_k is simply theta_sample[k].
            idx = int(np.argmax(theta_sample))
            full_action.append(self.all_segment_actions[idx])

        return torch.stack(full_action).to(self.device)

    def update(self, full_action, reward):
        """Update every segment's posterior given one ``(action, reward)`` observation.

        Uses the rank-one form of the Gaussian posterior update, which is
        algebraically identical to inverting ``Sigma^-1 + x x^T / noise`` but
        needs no matrix inversion:

            Sigma' = Sigma - (Sigma x)(Sigma x)^T / (noise + x^T Sigma x)
            mu'    = mu + Sigma x (reward - x^T mu) / (noise + x^T Sigma x)
        """
        self.all_seen_actions.append(full_action.detach().cpu().numpy().reshape(self.num_segments))
        self.all_seen_rewards.append(reward)

        for s_idx in range(self.num_segments):
            x = self.action_to_one_hot(full_action[s_idx].detach().cpu().numpy())
            prior_mu = self.mus[s_idx]
            prior_sigma = self.sigmas[s_idx]

            sigma_x = prior_sigma @ x
            denom = self.noise_variances[s_idx] + x @ sigma_x
            residual = reward - x @ prior_mu

            self.mus[s_idx] = prior_mu + sigma_x * (residual / denom)
            self.sigmas[s_idx] = prior_sigma - np.outer(sigma_x, sigma_x) / denom

