'''
Synthetic contextual bandits
'''

import torch
import random
import numpy as np
from pathlib import Path

'''
Linear bandit
'''
class LinearBandit(object):
    """Contextual bandit whose mean reward is linear in the arm's context.

    The mean reward of arm ``a`` is ``X[a] @ theta``. The observed reward adds
    zero-mean Gaussian noise with standard deviation ``sigma``.
    """

    def __init__(self, theta, sigma=1.0):
        self.theta = theta
        self.sigma = sigma

    def get_reward(self, X, arm):
        """Pull ``arm`` given the per-arm contexts ``X``.

        Args:
            - X: contextual feature matrix, shape (num_arm, dim_context)
            - arm: index of the arm to pull, int
        Return:
            - reward: the pulled arm's mean plus noise drawn as a
              one-element tensor on X's device
            - regret: expected regret, i.e. best mean minus the pulled arm's mean
        """
        means = torch.matmul(X, self.theta)
        chosen = means[arm]
        noise = torch.randn(1, device=X.device) * self.sigma
        return chosen + noise, means.max() - chosen



'''
Quadratic bandit: mean reward exp(-10 * (x^T theta)^2), a bump peaking when x^T theta = 0
'''
class QuadBandit(object):
    """Contextual bandit with a bump-shaped mean reward.

    The mean reward of arm ``a`` is ``exp(-10 * (X[a] @ theta) ** 2)``, which
    peaks at 1 when the linear score is 0. The observed reward adds zero-mean
    Gaussian noise with standard deviation ``sigma``.
    """

    def __init__(self, theta, sigma=1.0):
        self.theta = theta
        self.sigma = sigma

    def get_reward(self, X, arm):
        '''
        Args:
            - X: contextual feature vector with shape (num_arm, dim_context)
            - arm: which arm to pull, int
        Return:
            - reward: the pulled arm's mean plus noise drawn as a one-element
              tensor on X's device
            - regret: expected regret, best mean minus the pulled arm's mean

        '''
        prod = X @ self.theta
        # torch.exp (not np.exp) keeps the computation in torch on X's device.
        # np.exp would force a NumPy conversion, which raises for CUDA tensors
        # and for tensors that require grad.
        h = torch.exp(-10 * prod ** 2)
        reward = h[arm] + self.sigma * torch.randn(1, device=X.device)
        regret = h.max() - h[arm]
        return reward, regret


'''
Logistic bandit
'''
class LogisticBandit(object):
    """Contextual bandit with Bernoulli rewards.

    Arm ``a`` pays 1 with probability ``sigmoid(X[a] @ theta)`` and 0 otherwise.
    ``sigma`` is stored for signature parity with the other bandits but is not
    used by ``get_reward``.
    """

    def __init__(self, theta, sigma=1.0):
        self.theta = theta
        self.sigma = sigma

    def get_reward(self, X, arm):
        """Pull ``arm`` given the per-arm contexts ``X``.

        Args:
            - X: contextual feature matrix, shape (num_arm, dim_context)
            - arm: index of the arm to pull, int
        Return:
            - reward: sampled 0/1 outcome of the pulled arm
            - regret: expected regret, best success probability minus the
              pulled arm's success probability
        """
        p = torch.sigmoid(torch.matmul(X, self.theta))
        # One Bernoulli draw is made for every arm, not just the pulled one,
        # which fixes how much RNG state a pull consumes. Only the pulled
        # arm's draw is returned.
        outcomes = torch.bernoulli(p)
        gap = p.max() - p[arm]
        return outcomes[arm], gap



'''
Distance bandit
'''
class DistBandit(object):
    """Contextual bandit whose mean reward is the negative distance to theta.

    The mean reward of arm ``a`` is ``-||X[a] - theta||_2``. The observed
    reward adds zero-mean Gaussian noise with standard deviation ``sigma``.
    """

    def __init__(self, theta, sigma=1.0):
        self.theta = theta
        self.sigma = sigma

    @torch.no_grad()
    def get_reward(self, X, arm):
        """Pull ``arm`` given the per-arm contexts ``X``.

        Args:
            - X: contextual feature matrix, shape (num_arm, dim_context)
            - arm: index of the arm to pull, int
        Return:
            - reward: the pulled arm's mean plus noise drawn as a one-element
              tensor on X's device
            - regret: expected regret, best mean minus the pulled arm's mean
        """
        # Row-wise Euclidean distance to theta, negated so closer arms score higher.
        neg_dist = torch.norm(X - self.theta, dim=1).neg()
        chosen = neg_dist[arm]
        noise = self.sigma * torch.randn(1, device=X.device)
        return chosen + noise, neg_dist.max() - chosen
    


'''
Quadratic-form bandit
'''
class QuadFormBandit(object):
    """Contextual bandit whose mean reward is a scaled quadratic form.

    For arm ``a`` with context row ``x_a = X[a]``, the mean reward is
    ``scale * x_a^T (A A^T) x_a``, which equals ``scale * ||x_a^T A||^2``.
    The observed reward adds zero-mean Gaussian noise with std ``sigma``.
    """

    def __init__(self, A, sigma=1.0, device="cpu", precompute=True, scale=0.01):
        """
        Args:
            - A: shape (d, d). Fixed after creation/loading.
            - sigma: float, std dev of Gaussian noise
            - device: device that A (and M) live on; contexts are moved here
            - precompute: if True, cache M = A @ A.T once. This is only a speed
              trade-off; both code paths produce the same means.
            - scale: multiplier applied to the quadratic form. The default 0.01
              is the factor the precompute path has always used.
        """
        self.A = A.to(device)
        self.sigma = float(sigma)
        self.device = device
        self.scale = float(scale)
        self.M = (self.A @ self.A.T) if precompute else None

    @torch.no_grad()
    def get_reward(self, X, arm):
        """
        Args:
            - X: contextual feature matrix, shape (num_arm, d); moved to self.device
            - arm: which arm to pull, int
        Return:
            - reward: 0-d tensor, the pulled arm's mean plus Gaussian noise
            - regret: expected regret, best mean minus the pulled arm's mean
        """
        X = X.to(self.device)

        if self.M is not None:
            # Row-wise x^T M x without forming the full X M X^T matrix.
            quad = ((X @ self.M) * X).sum(dim=1)
        else:
            # x^T A A^T x == ||x^T A||^2, so this matches the cached path.
            AX = X @ self.A
            quad = (AX * AX).sum(dim=1)
        # Scale both paths identically so `precompute` never changes rewards.
        means = self.scale * quad

        noise = self.sigma * torch.randn((), device=self.device)
        reward = means[arm] + noise
        regret = means.max() - means[arm]
        return reward, regret

def generate_and_save_A(dim_context, out_path="A.pt", seed=None):
    """Sample a standard-normal (dim_context, dim_context) matrix and save it.

    If ``seed`` is given, the global torch RNG is reseeded before sampling.
    Note that this changes global RNG state for the caller. Missing parent
    directories of ``out_path`` are created.

    Returns:
        The POSIX-style string form of ``out_path`` that the tensor was saved to.
    """
    if seed is not None:
        torch.manual_seed(seed)
    matrix = torch.randn(dim_context, dim_context)

    target = Path(out_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    saved_to = target.as_posix()
    torch.save(matrix, saved_to)
    return saved_to

def load_A(path, map_location="cpu"):
    """Load an object saved with ``torch.save``, such as the output of
    ``generate_and_save_A``.

    ``map_location`` is forwarded to ``torch.load`` and controls which device
    the loaded tensors are placed on.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted paths.
    loaded = torch.load(path, map_location=map_location)
    return loaded
